Skip to content

Commit 8f6f9b0

Browse files
Adding support for transcriptions
1 parent 2704d53 commit 8f6f9b0

File tree

24 files changed

+1327
-4
lines changed

24 files changed

+1327
-4
lines changed
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package org.springframework.ai.openai;/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
import org.springframework.ai.chat.ChatOptions;
20+
import org.springframework.ai.chat.metadata.RateLimit;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.openai.api.OpenAiApi;
23+
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
24+
import org.springframework.ai.openai.metadata.OpenAiTranscriptionResponseMetadata;
25+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
26+
import org.springframework.ai.transcription.*;
27+
import org.springframework.core.io.Resource;
28+
import org.springframework.http.ResponseEntity;
29+
import org.springframework.retry.support.RetryTemplate;
30+
import org.springframework.util.Assert;
31+
import org.springframework.util.LinkedMultiValueMap;
32+
import org.springframework.util.MultiValueMap;
33+
34+
import java.time.Duration;
35+
import java.util.List;
36+
37+
/**
38+
* {@link TranscriptionClient} implementation for {@literal OpenAI} backed by
39+
* {@link OpenAiApi}.
40+
*
41+
* @author Michael Lavelle
42+
* @see TranscriptionClient
43+
* @see OpenAiApi
44+
*/
45+
public class OpenAiTranscriptionClient implements TranscriptionClient {
46+
47+
private final Logger logger = LoggerFactory.getLogger(getClass());
48+
49+
private OpenAiTranscriptionOptions defaultOptions = OpenAiTranscriptionOptions.builder()
50+
.withModel("whisper-1")
51+
.withTemperature(0.7f)
52+
.build();
53+
54+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
55+
.maxAttempts(10)
56+
.retryOn(OpenAiApiException.class)
57+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
58+
.build();
59+
60+
private final OpenAiApi openAiApi;
61+
62+
public OpenAiTranscriptionClient(OpenAiApi openAiApi) {
63+
Assert.notNull(openAiApi, "OpenAiApi must not be null");
64+
this.openAiApi = openAiApi;
65+
}
66+
67+
public OpenAiTranscriptionClient withDefaultOptions(OpenAiTranscriptionOptions options) {
68+
this.defaultOptions = options;
69+
return this;
70+
}
71+
72+
@Override
73+
public TranscriptionResponse call(TranscriptionRequest request) {
74+
75+
return this.retryTemplate.execute(ctx -> {
76+
Resource audioResource = request.getInstructions();
77+
78+
MultiValueMap<String, Object> requestBody = createRequestBody(request);
79+
80+
boolean jsonResponse = !requestBody.containsKey("response_format")
81+
|| requestBody.get("response_format").contains("json");
82+
83+
if (jsonResponse) {
84+
85+
ResponseEntity<OpenAiApi.Transcription> transcriptionEntity = this.openAiApi
86+
.transcriptionEntityJson(requestBody);
87+
88+
var transcription = transcriptionEntity.getBody();
89+
90+
if (transcription == null) {
91+
logger.warn("No transcription returned for request: {}", audioResource);
92+
return new TranscriptionResponse(null);
93+
}
94+
95+
Transcript transcript = new Transcript(transcription.text());
96+
97+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
98+
99+
return new TranscriptionResponse(transcript,
100+
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
101+
.withRateLimit(rateLimits));
102+
103+
}
104+
else {
105+
ResponseEntity<String> transcriptionEntity = this.openAiApi.transcriptionEntityText(requestBody);
106+
107+
var transcription = transcriptionEntity.getBody();
108+
109+
if (transcription == null) {
110+
logger.warn("No transcription returned for request: {}", audioResource);
111+
return new TranscriptionResponse(null);
112+
}
113+
114+
Transcript transcript = new Transcript(transcription);
115+
116+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
117+
118+
return new TranscriptionResponse(transcript,
119+
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody())
120+
.withRateLimit(rateLimits));
121+
122+
}
123+
124+
});
125+
}
126+
127+
private MultiValueMap<String, Object> createRequestBody(TranscriptionRequest transcriptionRequest) {
128+
129+
OpenAiApi.TranscriptionRequest request = new OpenAiApi.TranscriptionRequest();
130+
131+
if (this.defaultOptions != null) {
132+
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.TranscriptionRequest.class);
133+
}
134+
135+
if (transcriptionRequest.getOptions() != null) {
136+
if (transcriptionRequest.getOptions() instanceof TranscriptionOptions runtimeOptions) {
137+
OpenAiTranscriptionOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
138+
TranscriptionOptions.class, OpenAiTranscriptionOptions.class);
139+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.TranscriptionRequest.class);
140+
}
141+
else {
142+
throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: "
143+
+ transcriptionRequest.getOptions().getClass().getSimpleName());
144+
}
145+
}
146+
MultiValueMap<String, Object> requestBody = new LinkedMultiValueMap<>();
147+
if (request.responseFormat() != null) {
148+
requestBody.add("response_format", request.responseFormat().type());
149+
}
150+
if (request.prompt() != null) {
151+
requestBody.add("prompt", request.prompt());
152+
}
153+
if (request.temperature() != null) {
154+
requestBody.add("temperature", request.temperature());
155+
}
156+
if (request.language() != null) {
157+
requestBody.add("language", request.language());
158+
}
159+
if (request.model() != null) {
160+
requestBody.add("model", request.model());
161+
}
162+
if (transcriptionRequest.getInstructions() != null) {
163+
requestBody.add("file", transcriptionRequest.getInstructions());
164+
}
165+
return requestBody;
166+
}
167+
168+
}
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonInclude;
20+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
21+
import com.fasterxml.jackson.annotation.JsonProperty;
22+
import org.springframework.ai.openai.api.OpenAiApi.TranscriptionRequest.ResponseFormat;
23+
import org.springframework.ai.transcription.TranscriptionOptions;
24+
25+
/**
26+
* @author Michael Lavelle
27+
*/
28+
@JsonInclude(Include.NON_NULL)
29+
public class OpenAiTranscriptionOptions implements TranscriptionOptions {
30+
31+
// @formatter:off
32+
/**
33+
* ID of the model to use.
34+
*/
35+
private @JsonProperty("model") String model;
36+
37+
/**
38+
* An object specifying the format that the model must output. Setting to { "type":
39+
* "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
40+
*/
41+
private @JsonProperty("response_format") ResponseFormat responseFormat;
42+
43+
private @JsonProperty("prompt") String prompt;
44+
45+
private @JsonProperty("language") String language;
46+
47+
/**
48+
* What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
49+
* more random, while lower values like 0.2 will make it more focused and deterministic.
50+
*/
51+
private @JsonProperty("temperature") Float temperature = 0.8f;
52+
53+
54+
public static Builder builder() {
55+
return new Builder();
56+
}
57+
58+
public static class Builder {
59+
60+
protected OpenAiTranscriptionOptions options;
61+
62+
public Builder() {
63+
this.options = new OpenAiTranscriptionOptions();
64+
}
65+
66+
public Builder(OpenAiTranscriptionOptions options) {
67+
this.options = options;
68+
}
69+
70+
public Builder withModel(String model) {
71+
this.options.model = model;
72+
return this;
73+
}
74+
75+
public Builder withLanguage(String language) {
76+
this.options.language = language;
77+
return this;
78+
}
79+
80+
public Builder withPrompt(String prompt) {
81+
this.options.prompt = prompt;
82+
return this;
83+
}
84+
85+
public Builder withResponseFormat(ResponseFormat responseFormat) {
86+
this.options.responseFormat = responseFormat;
87+
return this;
88+
}
89+
90+
public Builder withTemperature(Float temperature) {
91+
this.options.temperature = temperature;
92+
return this;
93+
}
94+
95+
public OpenAiTranscriptionOptions build() {
96+
return this.options;
97+
}
98+
99+
}
100+
101+
public String getModel() {
102+
return this.model;
103+
}
104+
105+
public void setModel(String model) {
106+
this.model = model;
107+
}
108+
109+
public String getLanguage() {
110+
return this.language;
111+
}
112+
113+
public void setLanguage(String language) {
114+
this.language = language;
115+
}
116+
117+
public String getPrompt() {
118+
return this.prompt;
119+
}
120+
121+
public void setPrompt(String prompt) {
122+
this.prompt = prompt;
123+
}
124+
125+
public Float getTemperature() {
126+
return this.temperature;
127+
}
128+
129+
public void setTemperature(Float temperature) {
130+
this.temperature = temperature;
131+
}
132+
133+
134+
public ResponseFormat getResponseFormat() {
135+
return this.responseFormat;
136+
}
137+
138+
public void setResponseFormat(ResponseFormat responseFormat) {
139+
this.responseFormat = responseFormat;
140+
}
141+
142+
143+
144+
@Override
145+
public int hashCode() {
146+
final int prime = 31;
147+
int result = 1;
148+
result = prime * result + ((model == null) ? 0 : model.hashCode());
149+
result = prime * result + ((prompt == null) ? 0 : prompt.hashCode());
150+
result = prime * result + ((language == null) ? 0 : language.hashCode());
151+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
152+
return result;
153+
}
154+
155+
@Override
156+
public boolean equals(Object obj) {
157+
if (this == obj)
158+
return true;
159+
if (obj == null)
160+
return false;
161+
if (getClass() != obj.getClass())
162+
return false;
163+
OpenAiTranscriptionOptions other = (OpenAiTranscriptionOptions) obj;
164+
if (this.model == null) {
165+
if (other.model != null)
166+
return false;
167+
}
168+
else if (!model.equals(other.model))
169+
return false;
170+
if (this.prompt == null) {
171+
if (other.prompt != null)
172+
return false;
173+
}
174+
else if (!this.prompt.equals(other.prompt))
175+
return false;
176+
if (this.language == null) {
177+
if (other.language != null)
178+
return false;
179+
}
180+
else if (!this.language.equals(other.language))
181+
return false;
182+
if (this.responseFormat == null) {
183+
if (other.responseFormat != null)
184+
return false;
185+
}
186+
else if (!this.responseFormat.equals(other.responseFormat))
187+
return false;
188+
return true;
189+
}
190+
}

0 commit comments

Comments
 (0)