Skip to content

Commit 1e3eaec

Browse files
tzolovmarkpollack
authored andcommitted
Refactor and centralize Retry logic:
- Establish a new "spring-ai-retry" project, implementing a default HTTP error handler, RetryTemplate, and handling both Transient and Non-Transient Exceptions. - Streamline existing clients (e.g., OpenAI and MistralAI) to utilize "spring-ai-retry." - Integrate retry auto-configuration with customizable properties, extending it to OpenAI and MistralAI Auto-Configs. - Allow configuration of RetryTemplate and ResponseErrorHandler for various clients, including OpenAIChatClient, OpenAiEmbeddingClient, OpenAiAudioTranscriptionCline, OpenAiImageClient, MistralAiChatClient, and MistralAiEmbeddingClient. - Add tests for default RestTemplate and ResponseErrorHandler configurations in OpenAI and MistralAI. - Introduce new retry auto-config properties: "onClientErrors" and "onHttpCodes". - Implement tests for retry auto-config properties. - Generate missing license headers.
1 parent 78f73d1 commit 1e3eaec

File tree

57 files changed

+1411
-385
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1411
-385
lines changed

models/spring-ai-mistral-ai/pom.xml

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,24 +29,13 @@
2929
<version>${project.parent.version}</version>
3030
</dependency>
3131

32-
<dependency>
33-
<groupId>org.springframework</groupId>
34-
<artifactId>spring-web</artifactId>
35-
<version>${spring-framework.version}</version>
36-
</dependency>
37-
38-
<dependency>
39-
<groupId>org.springframework.retry</groupId>
40-
<artifactId>spring-retry</artifactId>
41-
<version>2.0.4</version>
42-
</dependency>
43-
32+
<dependency>
33+
<groupId>org.springframework.ai</groupId>
34+
<artifactId>spring-ai-retry</artifactId>
35+
<version>${project.parent.version}</version>
36+
</dependency>
4437

4538
<!-- Spring Framework -->
46-
<dependency>
47-
<groupId>org.springframework</groupId>
48-
<artifactId>spring-webflux</artifactId>
49-
</dependency>
5039
<dependency>
5140
<groupId>org.springframework</groupId>
5241
<artifactId>spring-context-support</artifactId>

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.ai.mistralai;
1717

18-
import java.time.Duration;
1918
import java.util.HashSet;
2019
import java.util.List;
2120
import java.util.Map;
@@ -41,10 +40,8 @@
4140
import org.springframework.ai.model.ModelOptionsUtils;
4241
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
4342
import org.springframework.ai.model.function.FunctionCallbackContext;
43+
import org.springframework.ai.retry.RetryUtils;
4444
import org.springframework.http.ResponseEntity;
45-
import org.springframework.retry.RetryCallback;
46-
import org.springframework.retry.RetryContext;
47-
import org.springframework.retry.RetryListener;
4845
import org.springframework.retry.support.RetryTemplate;
4946
import org.springframework.util.Assert;
5047
import org.springframework.util.CollectionUtils;
@@ -70,17 +67,7 @@ public class MistralAiChatClient extends
7067
*/
7168
private final MistralAiApi mistralAiApi;
7269

73-
private final RetryTemplate retryTemplate = RetryTemplate.builder()
74-
.maxAttempts(10)
75-
.retryOn(MistralAiApi.MistralAiApiException.class)
76-
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
77-
.withListener(new RetryListener() {
78-
public <T extends Object, E extends Throwable> void onError(RetryContext context,
79-
RetryCallback<T, E> callback, Throwable throwable) {
80-
log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
81-
};
82-
})
83-
.build();
70+
private final RetryTemplate retryTemplate;
8471

8572
public MistralAiChatClient(MistralAiApi mistralAiApi) {
8673
this(mistralAiApi,
@@ -93,46 +80,50 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) {
9380
}
9481

9582
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
96-
this(mistralAiApi, options, null);
83+
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
9784
}
9885

9986
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
100-
FunctionCallbackContext functionCallbackContext) {
87+
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
10188
super(functionCallbackContext);
10289
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
10390
Assert.notNull(options, "Options must not be null");
91+
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
10492
this.mistralAiApi = mistralAiApi;
10593
this.defaultOptions = options;
94+
this.retryTemplate = retryTemplate;
10695
}
10796

10897
@Override
10998
public ChatResponse call(Prompt prompt) {
110-
// return retryTemplate.execute(ctx -> {
11199
var request = createRequest(prompt, false);
112100

113-
// var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
114-
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
101+
return retryTemplate.execute(ctx -> {
115102

116-
var chatCompletion = completionEntity.getBody();
117-
if (chatCompletion == null) {
118-
log.warn("No chat completion returned for prompt: {}", prompt);
119-
return new ChatResponse(List.of());
120-
}
103+
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
121104

122-
List<Generation> generations = chatCompletion.choices()
123-
.stream()
124-
.map(choice -> new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
125-
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
126-
.toList();
105+
var chatCompletion = completionEntity.getBody();
106+
if (chatCompletion == null) {
107+
log.warn("No chat completion returned for prompt: {}", prompt);
108+
return new ChatResponse(List.of());
109+
}
110+
111+
List<Generation> generations = chatCompletion.choices()
112+
.stream()
113+
.map(choice -> new Generation(choice.message().content(),
114+
Map.of("role", choice.message().role().name()))
115+
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
116+
.toList();
127117

128-
return new ChatResponse(generations);
129-
// });
118+
return new ChatResponse(generations);
119+
});
130120
}
131121

132122
@Override
133123
public Flux<ChatResponse> stream(Prompt prompt) {
124+
var request = createRequest(prompt, true);
125+
134126
return retryTemplate.execute(ctx -> {
135-
var request = createRequest(prompt, true);
136127

137128
var completionChunks = this.mistralAiApi.chatCompletionStream(request);
138129

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,25 @@
1515
*/
1616
package org.springframework.ai.mistralai;
1717

18+
import java.util.List;
19+
1820
import org.slf4j.Logger;
1921
import org.slf4j.LoggerFactory;
22+
2023
import org.springframework.ai.document.Document;
2124
import org.springframework.ai.document.MetadataMode;
22-
import org.springframework.ai.embedding.*;
25+
import org.springframework.ai.embedding.AbstractEmbeddingClient;
26+
import org.springframework.ai.embedding.Embedding;
27+
import org.springframework.ai.embedding.EmbeddingOptions;
28+
import org.springframework.ai.embedding.EmbeddingRequest;
29+
import org.springframework.ai.embedding.EmbeddingResponse;
30+
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
2331
import org.springframework.ai.mistralai.api.MistralAiApi;
24-
import org.springframework.ai.mistralai.api.MistralAiApi.MistralAiApiException;
2532
import org.springframework.ai.model.ModelOptionsUtils;
26-
import org.springframework.retry.RetryCallback;
27-
import org.springframework.retry.RetryContext;
28-
import org.springframework.retry.RetryListener;
33+
import org.springframework.ai.retry.RetryUtils;
2934
import org.springframework.retry.support.RetryTemplate;
3035
import org.springframework.util.Assert;
3136

32-
import java.time.Duration;
33-
import java.util.List;
34-
3537
/**
3638
* @author Ricken Bazolo
3739
* @since 0.8.1
@@ -46,40 +48,33 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
4648

4749
private final MistralAiApi mistralAiApi;
4850

49-
private final RetryTemplate retryTemplate = RetryTemplate.builder()
50-
.maxAttempts(10)
51-
.retryOn(MistralAiApiException.class)
52-
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
53-
.withListener(new RetryListener() {
54-
public <T extends Object, E extends Throwable> void onError(RetryContext context,
55-
RetryCallback<T, E> callback, Throwable throwable) {
56-
log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
57-
};
58-
})
59-
.build();
51+
private final RetryTemplate retryTemplate;
6052

6153
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
6254
this(mistralAiApi, MetadataMode.EMBED);
6355
}
6456

6557
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
6658
this(mistralAiApi, metadataMode,
67-
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
59+
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
60+
RetryUtils.DEFAULT_RETRY_TEMPLATE);
6861
}
6962

7063
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
71-
this(mistralAiApi, MetadataMode.EMBED, options);
64+
this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
7265
}
7366

7467
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode,
75-
MistralAiEmbeddingOptions options) {
68+
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) {
7669
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
7770
Assert.notNull(metadataMode, "metadataMode must not be null");
7871
Assert.notNull(options, "options must not be null");
72+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
7973

8074
this.mistralAiApi = mistralAiApi;
8175
this.metadataMode = metadataMode;
8276
this.defaultOptions = options;
77+
this.retryTemplate = retryTemplate;
8378
}
8479

8580
@Override

models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java

Lines changed: 7 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
*/
1616
package org.springframework.ai.mistralai.api;
1717

18-
import java.io.IOException;
19-
import java.nio.charset.StandardCharsets;
2018
import java.util.List;
2119
import java.util.Map;
2220
import java.util.function.Consumer;
@@ -25,22 +23,18 @@
2523
import com.fasterxml.jackson.annotation.JsonInclude;
2624
import com.fasterxml.jackson.annotation.JsonInclude.Include;
2725
import com.fasterxml.jackson.annotation.JsonProperty;
28-
import com.fasterxml.jackson.databind.DeserializationFeature;
29-
import com.fasterxml.jackson.databind.ObjectMapper;
3026
import reactor.core.publisher.Flux;
3127
import reactor.core.publisher.Mono;
3228

3329
import org.springframework.ai.model.ModelOptionsUtils;
30+
import org.springframework.ai.retry.RetryUtils;
3431
import org.springframework.boot.context.properties.bind.ConstructorBinding;
3532
import org.springframework.core.ParameterizedTypeReference;
3633
import org.springframework.http.HttpHeaders;
3734
import org.springframework.http.MediaType;
3835
import org.springframework.http.ResponseEntity;
39-
import org.springframework.http.client.ClientHttpResponse;
40-
import org.springframework.lang.NonNull;
4136
import org.springframework.util.Assert;
4237
import org.springframework.util.CollectionUtils;
43-
import org.springframework.util.StreamUtils;
4438
import org.springframework.web.client.ResponseErrorHandler;
4539
import org.springframework.web.client.RestClient;
4640
import org.springframework.web.reactive.function.client.WebClient;
@@ -70,8 +64,6 @@ public class MistralAiApi {
7064

7165
private WebClient webClient;
7266

73-
private final ObjectMapper objectMapper;
74-
7567
/**
7668
* Create a new client api with DEFAULT_BASE_URL
7769
* @param mistralAiApiKey Mistral api Key.
@@ -86,75 +78,30 @@ public MistralAiApi(String mistralAiApiKey) {
8678
* @param mistralAiApiKey Mistral api Key.
8779
*/
8880
public MistralAiApi(String baseUrl, String mistralAiApiKey) {
89-
this(baseUrl, mistralAiApiKey, RestClient.builder());
81+
this(baseUrl, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
9082
}
9183

9284
/**
9385
* Create a new client api.
9486
* @param baseUrl api base URL.
9587
* @param mistralAiApiKey Mistral api Key.
9688
* @param restClientBuilder RestClient builder.
89+
* @param responseErrorHandler Response error handler.
9790
*/
98-
public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder) {
99-
100-
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
91+
public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
92+
ResponseErrorHandler responseErrorHandler) {
10193

10294
Consumer<HttpHeaders> jsonContentHeaders = headers -> {
10395
headers.setBearerAuth(mistralAiApiKey);
10496
headers.setContentType(MediaType.APPLICATION_JSON);
10597
};
10698

107-
var responseErrorHandler = new ResponseErrorHandler() {
108-
109-
@Override
110-
public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
111-
return response.getStatusCode().isError();
112-
}
113-
114-
@Override
115-
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
116-
if (response.getStatusCode().isError()) {
117-
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
118-
String message = String.format("%s - %s", response.getStatusCode().value(), error);
119-
if (response.getStatusCode().is4xxClientError()) {
120-
throw new MistralAiApiClientErrorException(message);
121-
}
122-
throw new MistralAiApiException(message);
123-
}
124-
}
125-
};
126-
12799
this.restClient = restClientBuilder.baseUrl(baseUrl)
128100
.defaultHeaders(jsonContentHeaders)
129101
.defaultStatusHandler(responseErrorHandler)
130102
.build();
131-
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
132-
}
133-
134-
public static class MistralAiApiException extends RuntimeException {
135-
136-
public MistralAiApiException(String message) {
137-
super(message);
138-
}
139-
140-
public MistralAiApiException(String message, Throwable t) {
141-
super(message, t);
142-
}
143-
144-
}
145-
146-
/**
147-
* Thrown on 4xx client errors, such as 401 - Incorrect API key provided, 401 - You
148-
* must be a member of an organization to use the API, 429 - Rate limit reached for
149-
* requests, 429 - You exceeded your current quota , please check your plan and
150-
* billing details.
151-
*/
152-
public static class MistralAiApiClientErrorException extends RuntimeException {
153-
154-
public MistralAiApiClientErrorException(String message) {
155-
super(message);
156-
}
157103

104+
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
158105
}
159106

160107
/**
@@ -594,7 +541,7 @@ public enum ChatCompletionFinishReason {
594541
// anticipation of future changes. Based on:
595542
// https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py
596543
@JsonProperty("error") ERROR,
597-
544+
598545
@JsonProperty("tool_calls") TOOL_CALLS
599546
// @formatter:on
600547

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
*/
3131
@SpringBootTest
3232
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
33-
public class MistralChatCompletionRequestTest {
33+
public class MistralAiChatCompletionRequestTest {
3434

3535
MistralAiChatClient chatClient = new MistralAiChatClient(new MistralAiApi("test"));
3636

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralEmbeddingIT.java renamed to models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
@SpringBootTest
3030
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
31-
class MistralEmbeddingIT {
31+
class MistralAiEmbeddingIT {
3232

3333
@Autowired
3434
private MistralAiEmbeddingClient mistralAiEmbeddingClient;

0 commit comments

Comments
 (0)