Skip to content

Commit 8acccad

Browse files
Adding support for transcriptions
1 parent 2704d53 commit 8acccad

File tree

24 files changed

+1256
-4
lines changed

24 files changed

+1256
-4
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package org.springframework.ai.openai;/*
2+
* Copyright 2023-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
import org.slf4j.Logger;
18+
import org.slf4j.LoggerFactory;
19+
import org.springframework.ai.chat.ChatOptions;
20+
import org.springframework.ai.chat.metadata.RateLimit;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.openai.api.OpenAiApi;
23+
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
24+
import org.springframework.ai.openai.metadata.OpenAiTranscriptionResponseMetadata;
25+
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
26+
import org.springframework.ai.transcription.*;
27+
import org.springframework.core.io.Resource;
28+
import org.springframework.http.ResponseEntity;
29+
import org.springframework.retry.support.RetryTemplate;
30+
import org.springframework.util.Assert;
31+
import org.springframework.util.LinkedMultiValueMap;
32+
import org.springframework.util.MultiValueMap;
33+
34+
import java.time.Duration;
35+
import java.util.List;
36+
37+
/**
38+
* {@link TranscriptionClient} implementation for {@literal OpenAI} backed by
39+
* {@link OpenAiApi}.
40+
*
41+
* @author Michael Lavelle
42+
* @see TranscriptionClient
43+
* @see OpenAiApi
44+
*/
45+
public class OpenAiTranscriptionClient implements TranscriptionClient {
46+
47+
private final Logger logger = LoggerFactory.getLogger(getClass());
48+
49+
private OpenAiTranscriptionOptions defaultOptions = OpenAiTranscriptionOptions.builder()
50+
.withModel("whisper-1")
51+
.withTemperature(0.7f)
52+
.build();
53+
54+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
55+
.maxAttempts(10)
56+
.retryOn(OpenAiApiException.class)
57+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
58+
.build();
59+
60+
private final OpenAiApi openAiApi;
61+
62+
public OpenAiTranscriptionClient(OpenAiApi openAiApi) {
63+
Assert.notNull(openAiApi, "OpenAiApi must not be null");
64+
this.openAiApi = openAiApi;
65+
}
66+
67+
public OpenAiTranscriptionClient withDefaultOptions(OpenAiTranscriptionOptions options) {
68+
this.defaultOptions = options;
69+
return this;
70+
}
71+
72+
@Override
73+
public TranscriptionResponse call(TranscriptionRequest request) {
74+
75+
return this.retryTemplate.execute(ctx -> {
76+
Resource audioResource = request.getInstructions();
77+
78+
MultiValueMap<String, Object> reqyestBody = createRequestBody(request);
79+
80+
ResponseEntity<OpenAiApi.Transcription> transcriptionEntity = this.openAiApi
81+
.transcriptionEntity(reqyestBody);
82+
83+
var transcription = transcriptionEntity.getBody();
84+
85+
if (transcription == null) {
86+
logger.warn("No transcription returned for request: {}", audioResource);
87+
return new TranscriptionResponse(null);
88+
}
89+
90+
Transcript transcript = new Transcript(transcription.text());
91+
92+
RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(transcriptionEntity);
93+
94+
return new TranscriptionResponse(transcript,
95+
OpenAiTranscriptionResponseMetadata.from(transcriptionEntity.getBody()).withRateLimit(rateLimits));
96+
});
97+
}
98+
99+
private MultiValueMap<String, Object> createRequestBody(TranscriptionRequest transcriptionRequest) {
100+
101+
OpenAiApi.TranscriptionRequest request = new OpenAiApi.TranscriptionRequest();
102+
103+
if (this.defaultOptions != null) {
104+
request = ModelOptionsUtils.merge(request, this.defaultOptions, OpenAiApi.TranscriptionRequest.class);
105+
}
106+
107+
if (transcriptionRequest.getOptions() != null) {
108+
if (transcriptionRequest.getOptions() instanceof TranscriptionOptions runtimeOptions) {
109+
OpenAiTranscriptionOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
110+
TranscriptionOptions.class, OpenAiTranscriptionOptions.class);
111+
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, OpenAiApi.TranscriptionRequest.class);
112+
}
113+
else {
114+
throw new IllegalArgumentException("Prompt options are not of type TranscriptionOptions: "
115+
+ transcriptionRequest.getOptions().getClass().getSimpleName());
116+
}
117+
}
118+
MultiValueMap<String, Object> requestBody = new LinkedMultiValueMap<>();
119+
if (request.responseFormat() != null) {
120+
requestBody.add("response_format", request.responseFormat().type());
121+
}
122+
if (request.prompt() != null) {
123+
requestBody.add("prompt", request.prompt());
124+
}
125+
if (request.temperature() != null) {
126+
requestBody.add("temperature", request.temperature());
127+
}
128+
if (request.language() != null) {
129+
requestBody.add("language", request.language());
130+
}
131+
if (request.model() != null) {
132+
requestBody.add("model", request.model());
133+
}
134+
if (transcriptionRequest.getInstructions() != null) {
135+
requestBody.add("file", transcriptionRequest.getInstructions());
136+
}
137+
return requestBody;
138+
}
139+
140+
}
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import com.fasterxml.jackson.annotation.JsonIgnore;
20+
import com.fasterxml.jackson.annotation.JsonInclude;
21+
import com.fasterxml.jackson.annotation.JsonInclude.Include;
22+
import com.fasterxml.jackson.annotation.JsonProperty;
23+
import org.springframework.ai.chat.ChatOptions;
24+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
25+
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoice;
26+
import org.springframework.ai.openai.api.OpenAiApi.FunctionTool;
27+
import org.springframework.ai.transcription.TranscriptionOptions;
28+
29+
import java.util.List;
30+
import java.util.Map;
31+
32+
/**
33+
* @author Michael Lavelle
34+
*/
35+
@JsonInclude(Include.NON_NULL)
36+
public class OpenAiTranscriptionOptions implements TranscriptionOptions {
37+
38+
// @formatter:off
39+
/**
40+
* ID of the model to use.
41+
*/
42+
private @JsonProperty("model") String model;
43+
44+
/**
45+
* An object specifying the format that the model must output. Setting to { "type":
46+
* "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
47+
*/
48+
private @JsonProperty("response_format") ResponseFormat responseFormat;
49+
50+
private @JsonProperty("prompt") String prompt;
51+
52+
private @JsonProperty("language") String language;
53+
54+
/**
55+
* What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output
56+
* more random, while lower values like 0.2 will make it more focused and deterministic.
57+
*/
58+
private @JsonProperty("temperature") Float temperature = 0.8f;
59+
60+
61+
public static Builder builder() {
62+
return new Builder();
63+
}
64+
65+
public static class Builder {
66+
67+
protected OpenAiTranscriptionOptions options;
68+
69+
public Builder() {
70+
this.options = new OpenAiTranscriptionOptions();
71+
}
72+
73+
public Builder(OpenAiTranscriptionOptions options) {
74+
this.options = options;
75+
}
76+
77+
public Builder withModel(String model) {
78+
this.options.model = model;
79+
return this;
80+
}
81+
82+
public Builder withLanguage(String language) {
83+
this.options.language = language;
84+
return this;
85+
}
86+
87+
public Builder withPrompt(String prompt) {
88+
this.options.prompt = prompt;
89+
return this;
90+
}
91+
92+
public Builder withResponseFormat(ResponseFormat responseFormat) {
93+
this.options.responseFormat = responseFormat;
94+
return this;
95+
}
96+
97+
public Builder withTemperature(Float temperature) {
98+
this.options.temperature = temperature;
99+
return this;
100+
}
101+
102+
public OpenAiTranscriptionOptions build() {
103+
return this.options;
104+
}
105+
106+
}
107+
108+
public String getModel() {
109+
return this.model;
110+
}
111+
112+
public void setModel(String model) {
113+
this.model = model;
114+
}
115+
116+
public String getLanguage() {
117+
return this.language;
118+
}
119+
120+
public void setLanguage(String language) {
121+
this.language = language;
122+
}
123+
124+
public String getPrompt() {
125+
return this.prompt;
126+
}
127+
128+
public void setPrompt(String prompt) {
129+
this.prompt = prompt;
130+
}
131+
132+
public Float getTemperature() {
133+
return this.temperature;
134+
}
135+
136+
public void setTemperature(Float temperature) {
137+
this.temperature = temperature;
138+
}
139+
140+
141+
public ResponseFormat getResponseFormat() {
142+
return this.responseFormat;
143+
}
144+
145+
public void setResponseFormat(ResponseFormat responseFormat) {
146+
this.responseFormat = responseFormat;
147+
}
148+
149+
150+
151+
@Override
152+
public int hashCode() {
153+
final int prime = 31;
154+
int result = 1;
155+
result = prime * result + ((model == null) ? 0 : model.hashCode());
156+
result = prime * result + ((prompt == null) ? 0 : prompt.hashCode());
157+
result = prime * result + ((language == null) ? 0 : language.hashCode());
158+
result = prime * result + ((responseFormat == null) ? 0 : responseFormat.hashCode());
159+
return result;
160+
}
161+
162+
@Override
163+
public boolean equals(Object obj) {
164+
if (this == obj)
165+
return true;
166+
if (obj == null)
167+
return false;
168+
if (getClass() != obj.getClass())
169+
return false;
170+
OpenAiTranscriptionOptions other = (OpenAiTranscriptionOptions) obj;
171+
if (this.model == null) {
172+
if (other.model != null)
173+
return false;
174+
}
175+
else if (!model.equals(other.model))
176+
return false;
177+
if (this.prompt == null) {
178+
if (other.prompt != null)
179+
return false;
180+
}
181+
else if (!this.prompt.equals(other.prompt))
182+
return false;
183+
if (this.language == null) {
184+
if (other.language != null)
185+
return false;
186+
}
187+
else if (!this.language.equals(other.language))
188+
return false;
189+
if (this.responseFormat == null) {
190+
if (other.responseFormat != null)
191+
return false;
192+
}
193+
else if (!this.responseFormat.equals(other.responseFormat))
194+
return false;
195+
return true;
196+
}
197+
}

0 commit comments

Comments
 (0)