Skip to content

Commit b58ef6d

Browse files
committed
OpenAI api improvements
- Factor out OpenAiApiClientErrorException and OpenAiApiException out of OpenAiApi into standalone common package. - Move the HTTP error handling and header setup into a shared ApiUtils. - Add ChatModel and EmbeddingModel enums to OpenAiApi. - Add ImageModel enum ot OpenAiImageClient.
1 parent 5054718 commit b58ef6d

File tree

9 files changed

+236
-100
lines changed

9 files changed

+236
-100
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
4444
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
4545
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
46+
import org.springframework.ai.openai.api.common.OpenAiApiException;
4647
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
47-
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
4848
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
4949
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
5050
import org.springframework.http.ResponseEntity;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
import org.springframework.ai.model.ModelOptionsUtils;
3434
import org.springframework.ai.openai.api.OpenAiApi;
3535
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
36-
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
3736
import org.springframework.ai.openai.api.OpenAiApi.Usage;
37+
import org.springframework.ai.openai.api.common.OpenAiApiException;
3838
import org.springframework.retry.RetryCallback;
3939
import org.springframework.retry.RetryContext;
4040
import org.springframework.retry.RetryListener;

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
import org.springframework.ai.image.ImageResponse;
3131
import org.springframework.ai.image.ImageResponseMetadata;
3232
import org.springframework.ai.model.ModelOptionsUtils;
33-
import org.springframework.ai.openai.api.OpenAiApi;
3433
import org.springframework.ai.openai.api.OpenAiImageApi;
34+
import org.springframework.ai.openai.api.common.OpenAiApiException;
3535
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
3636
import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata;
3737
import org.springframework.http.ResponseEntity;
@@ -59,7 +59,7 @@ public class OpenAiImageClient implements ImageClient {
5959

6060
public final RetryTemplate retryTemplate = RetryTemplate.builder()
6161
.maxAttempts(10)
62-
.retryOn(OpenAiApi.OpenAiApiException.class)
62+
.retryOn(OpenAiApiException.class)
6363
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
6464
.withListener(new RetryListener() {
6565
public <T extends Object, E extends Throwable> void onError(RetryContext context,

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 101 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@
1616

1717
package org.springframework.ai.openai.api;
1818

19-
import java.io.IOException;
20-
import java.nio.charset.StandardCharsets;
2119
import java.util.List;
2220
import java.util.Map;
23-
import java.util.function.Consumer;
2421
import java.util.function.Predicate;
2522

2623
import com.fasterxml.jackson.annotation.JsonInclude;
@@ -30,17 +27,12 @@
3027
import reactor.core.publisher.Mono;
3128

3229
import org.springframework.ai.model.ModelOptionsUtils;
30+
import org.springframework.ai.openai.api.common.ApiUtils;
3331
import org.springframework.boot.context.properties.bind.ConstructorBinding;
3432
import org.springframework.core.ParameterizedTypeReference;
35-
import org.springframework.http.HttpHeaders;
36-
import org.springframework.http.MediaType;
3733
import org.springframework.http.ResponseEntity;
38-
import org.springframework.http.client.ClientHttpResponse;
39-
import org.springframework.lang.NonNull;
4034
import org.springframework.util.Assert;
4135
import org.springframework.util.CollectionUtils;
42-
import org.springframework.util.StreamUtils;
43-
import org.springframework.web.client.ResponseErrorHandler;
4436
import org.springframework.web.client.RestClient;
4537
import org.springframework.web.reactive.function.client.WebClient;
4638

@@ -89,72 +81,128 @@ public OpenAiApi(String baseUrl, String openAiToken) {
8981
*/
9082
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
9183

92-
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
93-
headers.setBearerAuth(openAiToken);
94-
headers.setContentType(MediaType.APPLICATION_JSON);
95-
};
96-
97-
var responseErrorHandler = new ResponseErrorHandler() {
98-
99-
@Override
100-
public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
101-
return response.getStatusCode().isError();
102-
}
103-
104-
@Override
105-
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
106-
if (response.getStatusCode().isError()) {
107-
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
108-
String message = String.format("%s - %s", response.getStatusCode().value(), error);
109-
if (response.getStatusCode().is4xxClientError()) {
110-
throw new OpenAiApiClientErrorException(message);
111-
}
112-
throw new OpenAiApiException(message);
113-
}
114-
}
115-
};
116-
11784
this.restClient = restClientBuilder
11885
.baseUrl(baseUrl)
119-
.defaultHeaders(jsonContentHeaders)
120-
.defaultStatusHandler(responseErrorHandler)
86+
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
87+
.defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
12188
.build();
12289

12390
this.webClient = WebClient.builder()
12491
.baseUrl(baseUrl)
125-
.defaultHeaders(jsonContentHeaders)
92+
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
12693
.build();
12794
}
12895

129-
13096
/**
131-
* Non HTTP Error related exceptions
97+
* OpenAI Chat Completion Models:
98+
* <a href="https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo">GPT-4 and GPT-4 Turbo</a> and
99+
* <a href="https://platform.openai.com/docs/models/gpt-3-5-turbo">GPT-3.5 Turbo</a>.
132100
*/
133-
public static class OpenAiApiException extends RuntimeException {
101+
enum ChatModel {
102+
/**
103+
* (New) GPT-4 Turbo - latest GPT-4 model intended to reduce cases
104+
* of “laziness” where the model doesn’t complete a task.
105+
* Returns a maximum of 4,096 output tokens.
106+
* Context window: 128k tokens
107+
*/
108+
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
109+
110+
/**
111+
* Currently points to gpt-4-0125-preview - model featuring improved
112+
* instruction following, JSON mode, reproducible outputs,
113+
* parallel function calling, and more.
114+
* Returns a maximum of 4,096 output tokens
115+
* Context window: 128k tokens
116+
*/
117+
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"),
118+
119+
/**
120+
* GPT-4 with the ability to understand images, in addition
121+
* to all other GPT-4 Turbo capabilities. Currently points
122+
* to gpt-4-1106-vision-preview.
123+
* Returns a maximum of 4,096 output tokens
124+
* Context window: 128k tokens
125+
*/
126+
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
127+
128+
/**
129+
* Currently points to gpt-4-0613.
130+
* Snapshot of gpt-4 from June 13th 2023 with improved
131+
* function calling support.
132+
* Context window: 8k tokens
133+
*/
134+
GPT_4("gpt-4"),
135+
136+
/**
137+
* Currently points to gpt-4-32k-0613.
138+
* Snapshot of gpt-4-32k from June 13th 2023 with improved
139+
* function calling support.
140+
* Context window: 32k tokens
141+
*/
142+
GPT_4_32K("gpt-4-32k"),
143+
144+
/**
145+
*Currently points to gpt-3.5-turbo-0125.
146+
* model with higher accuracy at responding in requested
147+
* formats and a fix for a bug which caused a text
148+
* encoding issue for non-English language function calls.
149+
* Returns a maximum of 4,096
150+
* Context window: 16k tokens
151+
*/
152+
GPT_3_5_TURBO("gpt-3.5-turbo"),
134153

135-
public OpenAiApiException(String message) {
136-
super(message);
154+
/**
155+
* GPT-3.5 Turbo model with improved instruction following,
156+
* JSON mode, reproducible outputs, parallel function calling,
157+
* and more. Returns a maximum of 4,096 output tokens.
158+
* Context window: 16k tokens.
159+
*/
160+
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106");
161+
162+
public final String value;
163+
164+
ChatModel(String value) {
165+
this.value = value;
137166
}
138167

139-
public OpenAiApiException(String message, Throwable cause) {
140-
super(message, cause);
168+
public String getValue() {
169+
return value;
141170
}
142171
}
143172

144173
/**
145-
* Thrown on 4xx client errors, such as 401 - Incorrect API key provided,
146-
* 401 - You must be a member of an organization to use the API,
147-
* 429 - Rate limit reached for requests, 429 - You exceeded your current quota
148-
* , please check your plan and billing details.
174+
* OpenAI Embeddings Models:
175+
* <a href="https://platform.openai.com/docs/models/embeddings">Embeddings</a>.
149176
*/
150-
public static class OpenAiApiClientErrorException extends RuntimeException {
177+
enum EmbeddingModel {
178+
179+
/**
180+
* Most capable embedding model for both english and non-english tasks.
181+
* DIMENSION: 3072
182+
*/
183+
TEXT_EMBEDDING_3_LARGE("text-embedding-3-large"),
184+
185+
/**
186+
* Increased performance over 2nd generation ada embedding model.
187+
* DIMENSION: 1536
188+
*/
189+
TEXT_EMBEDDING_3_SMALL("text-embedding-3-small"),
190+
191+
/**
192+
* Most capable 2nd generation embedding model, replacing 16 first
193+
* generation models.
194+
* DIMENSION: 1536
195+
*/
196+
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002");
197+
198+
public final String value;
151199

152-
public OpenAiApiClientErrorException(String message) {
153-
super(message);
200+
EmbeddingModel(String value) {
201+
this.value = value;
154202
}
155203

156-
public OpenAiApiClientErrorException(String message, Throwable cause) {
157-
super(message, cause);
204+
public String getValue() {
205+
return value;
158206
}
159207
}
160208

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,14 @@
1515
*/
1616
package org.springframework.ai.openai.api;
1717

18-
import java.io.IOException;
19-
import java.nio.charset.StandardCharsets;
2018
import java.util.List;
21-
import java.util.function.Consumer;
2219

2320
import com.fasterxml.jackson.annotation.JsonInclude;
2421
import com.fasterxml.jackson.annotation.JsonProperty;
25-
import com.fasterxml.jackson.databind.DeserializationFeature;
26-
import com.fasterxml.jackson.databind.ObjectMapper;
2722

28-
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiClientErrorException;
29-
import org.springframework.ai.openai.api.OpenAiApi.OpenAiApiException;
30-
import org.springframework.http.HttpHeaders;
31-
import org.springframework.http.MediaType;
23+
import org.springframework.ai.openai.api.common.ApiUtils;
3224
import org.springframework.http.ResponseEntity;
33-
import org.springframework.http.client.ClientHttpResponse;
3425
import org.springframework.util.Assert;
35-
import org.springframework.util.StreamUtils;
36-
import org.springframework.web.client.ResponseErrorHandler;
3726
import org.springframework.web.client.RestClient;
3827

3928
/**
@@ -50,8 +39,6 @@ public class OpenAiImageApi {
5039

5140
private final RestClient restClient;
5241

53-
private final ObjectMapper objectMapper;
54-
5542
/**
5643
* Create a new OpenAI Image api with base URL set to https://api.openai.com
5744
* @param openAiToken OpenAI apiKey.
@@ -62,39 +49,42 @@ public OpenAiImageApi(String openAiToken) {
6249

6350
public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
6451

65-
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
66-
67-
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
68-
headers.setBearerAuth(openAiToken);
69-
headers.setContentType(MediaType.APPLICATION_JSON);
70-
};
71-
72-
var responseErrorHandler = new ResponseErrorHandler() {
73-
74-
@Override
75-
public boolean hasError(ClientHttpResponse response) throws IOException {
76-
return response.getStatusCode().isError();
77-
}
78-
79-
@Override
80-
public void handleError(ClientHttpResponse response) throws IOException {
81-
if (response.getStatusCode().isError()) {
82-
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
83-
String message = String.format("%s - %s", response.getStatusCode().value(), error);
84-
if (response.getStatusCode().is4xxClientError()) {
85-
throw new OpenAiApiClientErrorException(message);
86-
}
87-
throw new OpenAiApiException(message);
88-
}
89-
}
90-
};
91-
9252
this.restClient = restClientBuilder.baseUrl(baseUrl)
93-
.defaultHeaders(jsonContentHeaders)
94-
.defaultStatusHandler(responseErrorHandler)
53+
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
54+
.defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
9555
.build();
9656
}
9757

58+
/**
59+
* OpenAI Image API model.
60+
* <a href="https://platform.openai.com/docs/models/dall-e">DALL·E</a>
61+
*/
62+
enum ImageModel {
63+
64+
/**
65+
* The latest DALL·E model released in Nov 2023.
66+
*/
67+
DALL_E_3("dall-e-3"),
68+
69+
/**
70+
* The previous DALL·E model released in Nov 2022. The 2nd iteration of DALL·E
71+
* with more realistic, accurate, and 4x greater resolution images than the
72+
* original model.
73+
*/
74+
DALL_E_2("dall-e-2");
75+
76+
private final String model;
77+
78+
ImageModel(String model) {
79+
this.model = model;
80+
}
81+
82+
public String model() {
83+
return this.model;
84+
}
85+
86+
}
87+
9888
// @formatter:off
9989
@JsonInclude(JsonInclude.Include.NON_NULL)
10090
public record OpenAiImageRequest (

0 commit comments

Comments
 (0)