diff --git a/models/spring-ai-mistral-ai/pom.xml b/models/spring-ai-mistral-ai/pom.xml
index a613e4f6a1c..07fd0e087ef 100644
--- a/models/spring-ai-mistral-ai/pom.xml
+++ b/models/spring-ai-mistral-ai/pom.xml
@@ -29,24 +29,13 @@
${project.parent.version}
-
- org.springframework
- spring-web
- ${spring-framework.version}
-
-
-
- org.springframework.retry
- spring-retry
- 2.0.4
-
-
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
-
- org.springframework
- spring-webflux
-
org.springframework
spring-context-support
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
index 73bffce3e49..ee212817d35 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java
@@ -15,7 +15,6 @@
*/
package org.springframework.ai.mistralai;
-import java.time.Duration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -41,10 +40,8 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
-import org.springframework.retry.RetryCallback;
-import org.springframework.retry.RetryContext;
-import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -70,17 +67,7 @@ public class MistralAiChatClient extends
*/
private final MistralAiApi mistralAiApi;
- private final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(MistralAiApi.MistralAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .withListener(new RetryListener() {
- public void onError(RetryContext context,
- RetryCallback callback, Throwable throwable) {
- log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
- };
- })
- .build();
+ private final RetryTemplate retryTemplate;
public MistralAiChatClient(MistralAiApi mistralAiApi) {
this(mistralAiApi,
@@ -93,46 +80,50 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) {
}
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
- this(mistralAiApi, options, null);
+ this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
- FunctionCallbackContext functionCallbackContext) {
+ FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(options, "Options must not be null");
+ Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.mistralAiApi = mistralAiApi;
this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
}
@Override
public ChatResponse call(Prompt prompt) {
- // return retryTemplate.execute(ctx -> {
var request = createRequest(prompt, false);
- // var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
- ResponseEntity completionEntity = this.callWithFunctionSupport(request);
+ return retryTemplate.execute(ctx -> {
- var chatCompletion = completionEntity.getBody();
- if (chatCompletion == null) {
- log.warn("No chat completion returned for prompt: {}", prompt);
- return new ChatResponse(List.of());
- }
+ ResponseEntity completionEntity = this.callWithFunctionSupport(request);
- List generations = chatCompletion.choices()
- .stream()
- .map(choice -> new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
- .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
- .toList();
+ var chatCompletion = completionEntity.getBody();
+ if (chatCompletion == null) {
+ log.warn("No chat completion returned for prompt: {}", prompt);
+ return new ChatResponse(List.of());
+ }
+
+ List generations = chatCompletion.choices()
+ .stream()
+ .map(choice -> new Generation(choice.message().content(),
+ Map.of("role", choice.message().role().name()))
+ .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
+ .toList();
- return new ChatResponse(generations);
- // });
+ return new ChatResponse(generations);
+ });
}
@Override
public Flux stream(Prompt prompt) {
+ var request = createRequest(prompt, true);
+
return retryTemplate.execute(ctx -> {
- var request = createRequest(prompt, true);
var completionChunks = this.mistralAiApi.chatCompletionStream(request);
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java
index 30e17b36f38..e42908c348d 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java
@@ -15,23 +15,25 @@
*/
package org.springframework.ai.mistralai;
+import java.util.List;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
-import org.springframework.ai.embedding.*;
+import org.springframework.ai.embedding.AbstractEmbeddingClient;
+import org.springframework.ai.embedding.Embedding;
+import org.springframework.ai.embedding.EmbeddingOptions;
+import org.springframework.ai.embedding.EmbeddingRequest;
+import org.springframework.ai.embedding.EmbeddingResponse;
+import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.mistralai.api.MistralAiApi;
-import org.springframework.ai.mistralai.api.MistralAiApi.MistralAiApiException;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.retry.RetryCallback;
-import org.springframework.retry.RetryContext;
-import org.springframework.retry.RetryListener;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
-import java.time.Duration;
-import java.util.List;
-
/**
* @author Ricken Bazolo
* @since 0.8.1
@@ -46,17 +48,7 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {
private final MistralAiApi mistralAiApi;
- private final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(MistralAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .withListener(new RetryListener() {
- public void onError(RetryContext context,
- RetryCallback callback, Throwable throwable) {
- log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
- };
- })
- .build();
+ private final RetryTemplate retryTemplate;
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
this(mistralAiApi, MetadataMode.EMBED);
@@ -64,22 +56,25 @@ public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
this(mistralAiApi, metadataMode,
- MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
+ MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
+ RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
- this(mistralAiApi, MetadataMode.EMBED, options);
+ this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode,
- MistralAiEmbeddingOptions options) {
+ MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) {
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.mistralAiApi = mistralAiApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
}
@Override
diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
index f227e5f521a..d2bd75c4a7c 100644
--- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
+++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java
@@ -15,8 +15,6 @@
*/
package org.springframework.ai.mistralai.api;
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
@@ -25,22 +23,18 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
-import com.fasterxml.jackson.databind.DeserializationFeature;
-import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import org.springframework.ai.model.ModelOptionsUtils;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
-import org.springframework.http.client.ClientHttpResponse;
-import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
-import org.springframework.util.StreamUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
@@ -70,8 +64,6 @@ public class MistralAiApi {
private WebClient webClient;
- private final ObjectMapper objectMapper;
-
/**
* Create a new client api with DEFAULT_BASE_URL
* @param mistralAiApiKey Mistral api Key.
@@ -86,7 +78,7 @@ public MistralAiApi(String mistralAiApiKey) {
* @param mistralAiApiKey Mistral api Key.
*/
public MistralAiApi(String baseUrl, String mistralAiApiKey) {
- this(baseUrl, mistralAiApiKey, RestClient.builder());
+ this(baseUrl, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
@@ -94,67 +86,22 @@ public MistralAiApi(String baseUrl, String mistralAiApiKey) {
* @param baseUrl api base URL.
* @param mistralAiApiKey Mistral api Key.
* @param restClientBuilder RestClient builder.
+ * @param responseErrorHandler Response error handler.
*/
- public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder) {
-
- this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+ public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
+ ResponseErrorHandler responseErrorHandler) {
Consumer jsonContentHeaders = headers -> {
headers.setBearerAuth(mistralAiApiKey);
headers.setContentType(MediaType.APPLICATION_JSON);
};
- var responseErrorHandler = new ResponseErrorHandler() {
-
- @Override
- public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
- return response.getStatusCode().isError();
- }
-
- @Override
- public void handleError(@NonNull ClientHttpResponse response) throws IOException {
- if (response.getStatusCode().isError()) {
- String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
- String message = String.format("%s - %s", response.getStatusCode().value(), error);
- if (response.getStatusCode().is4xxClientError()) {
- throw new MistralAiApiClientErrorException(message);
- }
- throw new MistralAiApiException(message);
- }
- }
- };
-
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();
- this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
- }
-
- public static class MistralAiApiException extends RuntimeException {
-
- public MistralAiApiException(String message) {
- super(message);
- }
-
- public MistralAiApiException(String message, Throwable t) {
- super(message, t);
- }
-
- }
-
- /**
- * Thrown on 4xx client errors, such as 401 - Incorrect API key provided, 401 - You
- * must be a member of an organization to use the API, 429 - Rate limit reached for
- * requests, 429 - You exceeded your current quota , please check your plan and
- * billing details.
- */
- public static class MistralAiApiClientErrorException extends RuntimeException {
-
- public MistralAiApiClientErrorException(String message) {
- super(message);
- }
+ this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
}
/**
@@ -594,7 +541,7 @@ public enum ChatCompletionFinishReason {
// anticipation of future changes. Based on:
// https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py
@JsonProperty("error") ERROR,
-
+
@JsonProperty("tool_calls") TOOL_CALLS
// @formatter:on
diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java
new file mode 100644
index 00000000000..4c6a93a900e
--- /dev/null
+++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java
@@ -0,0 +1,192 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.mistralai;
+
+import java.util.List;
+import java.util.Optional;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.mistralai.api.MistralAiApi;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionFinishReason;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role;
+import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
+import org.springframework.ai.mistralai.api.MistralAiApi.Embedding;
+import org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingList;
+import org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingRequest;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.http.ResponseEntity;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Christian Tzolov
+ */
+@SuppressWarnings("unchecked")
+@ExtendWith(MockitoExtension.class)
+public class RetryTests {
+
+ private class TestRetryListener implements RetryListener {
+
+ int onErrorRetryCount = 0;
+
+ int onSuccessRetryCount = 0;
+
+ @Override
+ public void onSuccess(RetryContext context, RetryCallback callback, T result) {
+ onSuccessRetryCount = context.getRetryCount();
+ }
+
+ @Override
+ public void onError(RetryContext context, RetryCallback callback,
+ Throwable throwable) {
+ onErrorRetryCount = context.getRetryCount();
+ }
+
+ }
+
+ private TestRetryListener retryListener;
+
+ private RetryTemplate retryTemplate;
+
+ private @Mock MistralAiApi mistralAiApi;
+
+ private MistralAiChatClient chatClient;
+
+ private MistralAiEmbeddingClient embeddingClient;
+
+ @BeforeEach
+ public void beforeEach() {
+ retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+ retryListener = new TestRetryListener();
+ retryTemplate.registerListener(retryListener);
+
+ chatClient = new MistralAiChatClient(mistralAiApi,
+ MistralAiChatOptions.builder()
+ .withTemperature(0.7f)
+ .withTopP(1f)
+ .withSafePrompt(false)
+ .withModel(MistralAiApi.ChatModel.TINY.getValue())
+ .build(),
+ null, retryTemplate);
+ embeddingClient = new MistralAiEmbeddingClient(mistralAiApi, MetadataMode.EMBED,
+ MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
+ retryTemplate);
+ }
+
+ @Test
+ public void mistralAiChatTransientError() {
+
+ var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
+ ChatCompletionFinishReason.STOP);
+ ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model",
+ List.of(choice), new MistralAiApi.Usage(10, 10, 10));
+
+ when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
+
+ var result = chatClient.call(new Prompt("text"));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void mistralAiChatNonTransientError() {
+ when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
+ }
+
+ @Test
+ public void mistralAiChatStreamTransientError() {
+
+ var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
+ ChatCompletionFinishReason.STOP);
+ ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l,
+ "model", List.of(choice));
+
+ when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(Flux.just(expectedChatCompletion));
+
+ var result = chatClient.stream(new Prompt("text"));
+
+ assertThat(result).isNotNull();
+ assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void mistralAiChatStreamNonTransientError() {
+ when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
+ }
+
+ @Test
+ public void mistralAiEmbeddingTransientError() {
+
+ EmbeddingList expectedEmbeddings = new EmbeddingList<>("list",
+ List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new MistralAiApi.Usage(10, 10, 10));
+
+ when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
+
+ var result = embeddingClient
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void mistralAiEmbeddingNonTransientError() {
+ when(mistralAiApi.embeddings(isA(EmbeddingRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> embeddingClient
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
+ }
+
+}
diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml
index 7fbfce8ea56..10be9e4c14f 100644
--- a/models/spring-ai-openai/pom.xml
+++ b/models/spring-ai-openai/pom.xml
@@ -29,9 +29,9 @@
- org.springframework.retry
- spring-retry
- 2.0.4
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
@@ -57,11 +57,6 @@
${victools.version}
-
- org.springframework
- spring-webflux
-
-
org.springframework
spring-context-support
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
index 68e24808de4..30982ec12af 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java
@@ -31,8 +31,6 @@
package org.springframework.ai.openai;
-import java.time.Duration;
-
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,12 +38,12 @@
import org.springframework.ai.model.ModelClient;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
-import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.audio.transcription.AudioTranscription;
import org.springframework.ai.openai.audio.transcription.AudioTranscriptionPrompt;
import org.springframework.ai.openai.audio.transcription.AudioTranscriptionResponse;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
@@ -66,11 +64,7 @@ public class OpenAiAudioTranscriptionClient
private final OpenAiAudioTranscriptionOptions defaultOptions;
- public final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(OpenAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .build();
+ public final RetryTemplate retryTemplate;
private final OpenAiAudioApi audioApi;
@@ -80,14 +74,18 @@ public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi) {
.withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue())
.withResponseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON)
.withTemperature(0.7f)
- .build());
+ .build(),
+ RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
- public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) {
+ public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options,
+ RetryTemplate retryTemplate) {
Assert.notNull(audioApi, "OpenAiAudioApi must not be null");
Assert.notNull(options, "OpenAiTranscriptionOptions must not be null");
+ Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.audioApi = audioApi;
this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
}
public String call(Resource audioResource) {
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
index bd977a58d05..9e83ce6b07d 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
@@ -15,7 +15,6 @@
*/
package org.springframework.ai.openai;
-import java.time.Duration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -43,14 +42,11 @@
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
-import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
-import org.springframework.retry.RetryCallback;
-import org.springframework.retry.RetryContext;
-import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -73,7 +69,7 @@ public class OpenAiChatClient extends
AbstractFunctionCallSupport>
implements ChatClient, StreamingChatClient {
- private final Logger logger = LoggerFactory.getLogger(getClass());
+ private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClient.class);
/**
* The default options used for the chat completion requests.
@@ -83,18 +79,7 @@ public class OpenAiChatClient extends
/**
* The retry template used to retry the OpenAI API calls.
*/
- public final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(OpenAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .withListener(new RetryListener() {
- @Override
- public void onError(RetryContext context,
- RetryCallback callback, Throwable throwable) {
- logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
- };
- })
- .build();
+ public final RetryTemplate retryTemplate;
/**
* Low-level access to the OpenAI API.
@@ -107,24 +92,26 @@ public OpenAiChatClient(OpenAiApi openAiApi) {
}
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options) {
- this(openAiApi, options, null);
+ this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options,
- FunctionCallbackContext functionCallbackContext) {
+ FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(openAiApi, "OpenAiApi must not be null");
Assert.notNull(options, "Options must not be null");
+ Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.openAiApi = openAiApi;
this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
}
@Override
public ChatResponse call(Prompt prompt) {
- return this.retryTemplate.execute(ctx -> {
+ ChatCompletionRequest request = createRequest(prompt, false);
- ChatCompletionRequest request = createRequest(prompt, false);
+ return this.retryTemplate.execute(ctx -> {
ResponseEntity completionEntity = this.callWithFunctionSupport(request);
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java
index 839828f9541..f5d5730d750 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java
@@ -15,7 +15,6 @@
*/
package org.springframework.ai.openai;
-import java.time.Duration;
import java.util.List;
import org.slf4j.Logger;
@@ -33,10 +32,7 @@
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
import org.springframework.ai.openai.api.OpenAiApi.Usage;
-import org.springframework.ai.openai.api.common.OpenAiApiException;
-import org.springframework.retry.RetryCallback;
-import org.springframework.retry.RetryContext;
-import org.springframework.retry.RetryListener;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
@@ -51,17 +47,7 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient {
private final OpenAiEmbeddingOptions defaultOptions;
- private final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(OpenAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .withListener(new RetryListener() {
- public void onError(RetryContext context,
- RetryCallback callback, Throwable throwable) {
- logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
- };
- })
- .build();
+ private final RetryTemplate retryTemplate;
private final OpenAiApi openAiApi;
@@ -73,17 +59,21 @@ public OpenAiEmbeddingClient(OpenAiApi openAiApi) {
public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) {
this(openAiApi, metadataMode,
- OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build());
+ OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
+ RetryUtils.DEFAULT_RETRY_TEMPLATE);
}
- public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options) {
+ public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options,
+ RetryTemplate retryTemplate) {
Assert.notNull(openAiApi, "OpenAiService must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.openAiApi = openAiApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
+ this.retryTemplate = retryTemplate;
}
@Override
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java
index e9020302e38..1bc78d4c23d 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java
@@ -15,7 +15,6 @@
*/
package org.springframework.ai.openai;
-import java.time.Duration;
import java.util.List;
import org.slf4j.Logger;
@@ -30,13 +29,10 @@
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.openai.api.OpenAiImageApi;
-import org.springframework.ai.openai.api.common.OpenAiApiException;
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
-import org.springframework.retry.RetryCallback;
-import org.springframework.retry.RetryContext;
-import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
@@ -50,38 +46,32 @@
*/
public class OpenAiImageClient implements ImageClient {
- private final Logger logger = LoggerFactory.getLogger(getClass());
+ private final static Logger logger = LoggerFactory.getLogger(OpenAiImageClient.class);
private OpenAiImageOptions defaultOptions;
private final OpenAiImageApi openAiImageApi;
- public final RetryTemplate retryTemplate = RetryTemplate.builder()
- .maxAttempts(10)
- .retryOn(OpenAiApiException.class)
- .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
- .withListener(new RetryListener() {
- public void onError(RetryContext context,
- RetryCallback callback, Throwable throwable) {
- logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
- };
- })
- .build();
+ public final RetryTemplate retryTemplate;
public OpenAiImageClient(OpenAiImageApi openAiImageApi) {
+ this(openAiImageApi, OpenAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
+ }
+
+ public OpenAiImageClient(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions,
+ RetryTemplate retryTemplate) {
Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null");
+ Assert.notNull(defaultOptions, "defaultOptions must not be null");
+ Assert.notNull(retryTemplate, "retryTemplate must not be null");
this.openAiImageApi = openAiImageApi;
+ this.defaultOptions = defaultOptions;
+ this.retryTemplate = retryTemplate;
}
public OpenAiImageOptions getDefaultOptions() {
return this.defaultOptions;
}
- public OpenAiImageClient withDefaultOptions(OpenAiImageOptions defaultOptions) {
- this.defaultOptions = defaultOptions;
- return this;
- }
-
@Override
public ImageResponse call(ImagePrompt imagePrompt) {
return this.retryTemplate.execute(ctx -> {
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java
new file mode 100644
index 00000000000..36ad4b8f758
--- /dev/null
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.openai.api;
+
+import java.util.function.Consumer;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+
+/**
+ * @author Christian Tzolov
+ */
+public class ApiUtils {
+
+ public static final String DEFAULT_BASE_URL = "https://api.openai.com";
+
+ public static Consumer getJsonContentHeaders(String apiKey) {
+ return (headers) -> {
+ headers.setBearerAuth(apiKey);
+ headers.setContentType(MediaType.APPLICATION_JSON);
+ };
+ };
+
+}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 647f06c8d1b..795014139cd 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -26,14 +26,13 @@
import reactor.core.publisher.Mono;
import org.springframework.ai.model.ModelOptionsUtils;
-import org.springframework.ai.openai.api.common.ApiUtils;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
-import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
-import org.springframework.util.MultiValueMap;
+import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
@@ -53,8 +52,6 @@ public class OpenAiApi {
private final RestClient restClient;
- private final RestClient multipartRestClient;
-
private final WebClient webClient;
/**
@@ -84,20 +81,23 @@ public OpenAiApi(String baseUrl, String openAiToken) {
* @param restClientBuilder RestClient builder.
*/
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
+ this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
+ }
+
+ /**
+ * Create a new chat completion api.
+ *
+ * @param baseUrl api base URL.
+ * @param openAiToken OpenAI apiKey.
+ * @param restClientBuilder RestClient builder.
+ * @param responseErrorHandler Response error handler.
+ */
+ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
this.restClient = restClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
- .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
- .build();
-
- this.multipartRestClient = restClientBuilder
- .baseUrl(baseUrl)
- .defaultHeaders(multipartFormDataHeaders -> {
- multipartFormDataHeaders.setBearerAuth(openAiToken);
- multipartFormDataHeaders.setContentType(MediaType.MULTIPART_FORM_DATA);
- })
- .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
+ .defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
.build();
this.webClient = WebClient.builder()
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java
index fcdd1fd4d13..65e5ca310af 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java
@@ -21,12 +21,13 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
-import org.springframework.ai.openai.api.common.ApiUtils;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
+import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
/**
@@ -45,7 +46,7 @@ public class OpenAiAudioApi {
* @param openAiToken OpenAI apiKey.
*/
public OpenAiAudioApi(String openAiToken) {
- this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder());
+ this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
/**
@@ -53,12 +54,14 @@ public OpenAiAudioApi(String openAiToken) {
* @param baseUrl api base URL.
* @param openAiToken OpenAI apiKey.
* @param restClientBuilder RestClient builder.
+ * @param responseErrorHandler Response error handler.
*/
- public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
+ public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
+ ResponseErrorHandler responseErrorHandler) {
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(headers -> {
headers.setBearerAuth(openAiToken);
- }).defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER).build();
+ }).defaultStatusHandler(responseErrorHandler).build();
}
/**
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java
index 6499f51174f..18b3d0bff4e 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java
@@ -20,9 +20,10 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
-import org.springframework.ai.openai.api.common.ApiUtils;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
+import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
/**
@@ -44,11 +45,29 @@ public OpenAiImageApi(String openAiToken) {
this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder());
}
+ /**
+ * Create a new OpenAI Image API with the provided base URL.
+ * @param baseUrl the base URL for the OpenAI API.
+ * @param openAiToken OpenAI apiKey.
+ * @param restClientBuilder the rest client builder to use.
+ */
public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
+ this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
+ }
+
+ /**
+ * Create a new OpenAI Image API with the provided base URL.
+ * @param baseUrl the base URL for the OpenAI API.
+ * @param openAiToken OpenAI apiKey.
+ * @param restClientBuilder the rest client builder to use.
+ * @param responseErrorHandler the response error handler to use.
+ */
+ public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder,
+ ResponseErrorHandler responseErrorHandler) {
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken))
- .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
+ .defaultStatusHandler(responseErrorHandler)
.build();
}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java
index a3190312d7f..300cdfafbb8 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java
@@ -22,7 +22,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.openai.api.OpenAiApi;
-import org.springframework.ai.openai.chat.api.tool.MockWeatherService;
+import org.springframework.ai.openai.api.tool.MockWeatherService;
import static org.assertj.core.api.Assertions.assertThat;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
similarity index 98%
rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java
rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
index e54b7e87792..182a11dcec2 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.openai.chat.api;
+package org.springframework.ai.openai.api;
import java.util.List;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java
similarity index 96%
rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java
rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java
index 96c45d56bf6..ffdccf035fb 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java
@@ -13,9 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.openai.chat.api;
+package org.springframework.ai.openai.api;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java
similarity index 97%
rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java
rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java
index e1364756ac3..db41af1f0d4 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.openai.chat.api.tool;
+package org.springframework.ai.openai.api.tool;
import java.util.function.Function;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
similarity index 99%
rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java
rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
index b3ac24b09d9..4138a24a98c 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
@@ -13,7 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.openai.chat.api.tool;
+
+package org.springframework.ai.openai.api.tool;
import java.util.ArrayList;
import java.util.List;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java
index 8d99b60fe13..3b2629714f8 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java
@@ -26,6 +26,7 @@
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionMetadata;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders;
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.autoconfigure.web.client.RestClientTest;
@@ -151,7 +152,7 @@ static class Config {
@Bean
public OpenAiAudioApi chatCompletionApi(RestClient.Builder builder) {
- return new OpenAiAudioApi("", TEST_API_KEY, builder);
+ return new OpenAiAudioApi("", TEST_API_KEY, builder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}
@Bean
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
index a5f754a2ce7..579502daff1 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
@@ -37,7 +37,7 @@
import org.springframework.ai.model.function.FunctionCallbackWrapper;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiTestConfiguration;
-import org.springframework.ai.openai.chat.api.tool.MockWeatherService;
+import org.springframework.ai.openai.api.tool.MockWeatherService;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java
new file mode 100644
index 00000000000..4abe2625c3d
--- /dev/null
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java
@@ -0,0 +1,273 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.openai.chat;
+
+import java.util.List;
+import java.util.Optional;
+
+import groovyjarjarpicocli.CommandLine.TraceLevel;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import reactor.core.publisher.Flux;
+
+import org.springframework.ai.chat.prompt.Prompt;
+import org.springframework.ai.document.MetadataMode;
+import org.springframework.ai.image.ImageMessage;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
+import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions;
+import org.springframework.ai.openai.OpenAiChatClient;
+import org.springframework.ai.openai.OpenAiChatOptions;
+import org.springframework.ai.openai.OpenAiEmbeddingClient;
+import org.springframework.ai.openai.OpenAiEmbeddingOptions;
+import org.springframework.ai.openai.OpenAiImageClient;
+import org.springframework.ai.openai.OpenAiImageOptions;
+import org.springframework.ai.openai.api.OpenAiApi;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
+import org.springframework.ai.openai.api.OpenAiApi.Embedding;
+import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
+import org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest;
+import org.springframework.ai.openai.api.OpenAiAudioApi;
+import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse;
+import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat;
+import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest;
+import org.springframework.ai.openai.api.OpenAiImageApi;
+import org.springframework.ai.openai.api.OpenAiImageApi.Data;
+import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageRequest;
+import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageResponse;
+import org.springframework.ai.openai.audio.transcription.AudioTranscriptionPrompt;
+import org.springframework.ai.openai.audio.transcription.AudioTranscriptionResponse;
+import org.springframework.ai.retry.RetryUtils;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.http.ResponseEntity;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Christian Tzolov
+ */
+@SuppressWarnings("unchecked")
+@ExtendWith(MockitoExtension.class)
+public class RetryTests {
+
+ private class TestRetryListener implements RetryListener {
+
+ int onErrorRetryCount = 0;
+
+ int onSuccessRetryCount = 0;
+
+ @Override
+ public void onSuccess(RetryContext context, RetryCallback callback, T result) {
+ onSuccessRetryCount = context.getRetryCount();
+ }
+
+ @Override
+ public void onError(RetryContext context, RetryCallback callback,
+ Throwable throwable) {
+ onErrorRetryCount = context.getRetryCount();
+ }
+
+ }
+
+ private TestRetryListener retryListener;
+
+ private RetryTemplate retryTemplate;
+
+ private @Mock OpenAiApi openAiApi;
+
+ private @Mock OpenAiAudioApi openAiAudioApi;
+
+ private @Mock OpenAiImageApi openAiImageApi;
+
+ private OpenAiChatClient chatClient;
+
+ private OpenAiEmbeddingClient embeddingClient;
+
+ private OpenAiAudioTranscriptionClient audioTranscriptionClient;
+
+ private OpenAiImageClient imageClient;
+
+ @BeforeEach
+ public void beforeEach() {
+ retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
+ retryListener = new TestRetryListener();
+ retryTemplate.registerListener(retryListener);
+
+ chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder().build(), null, retryTemplate);
+ embeddingClient = new OpenAiEmbeddingClient(openAiApi, MetadataMode.EMBED,
+ OpenAiEmbeddingOptions.builder().build(), retryTemplate);
+ audioTranscriptionClient = new OpenAiAudioTranscriptionClient(openAiAudioApi,
+ OpenAiAudioTranscriptionOptions.builder()
+ .withModel("model")
+ .withResponseFormat(TranscriptResponseFormat.JSON)
+ .build(),
+ retryTemplate);
+ imageClient = new OpenAiImageClient(openAiImageApi, OpenAiImageOptions.builder().build(), retryTemplate);
+ }
+
+ @Test
+ public void openAiChatTransientError() {
+
+ var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0,
+ new ChatCompletionMessage("Response", Role.ASSISTANT), null);
+ ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666l, "model", null, null,
+ new OpenAiApi.Usage(10, 10, 10));
+
+ when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion)));
+
+ var result = chatClient.call(new Prompt("text"));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput().getContent()).isSameAs("Response");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void openAiChatNonTransientError() {
+ when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text")));
+ }
+
+ @Test
+ public void openAiChatStreamTransientError() {
+
+ var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0,
+ new ChatCompletionMessage("Response", Role.ASSISTANT), null);
+ ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666l, "model", null,
+ null);
+
+ when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(Flux.just(expectedChatCompletion));
+
+ var result = chatClient.stream(new Prompt("text"));
+
+ assertThat(result).isNotNull();
+ assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void openAiChatStreamNonTransientError() {
+ when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text")));
+ }
+
+ @Test
+ public void openAiEmbeddingTransientError() {
+
+ EmbeddingList expectedEmbeddings = new EmbeddingList<>("list",
+ List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new OpenAiApi.Usage(10, 10, 10));
+
+ when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings)));
+
+ var result = embeddingClient
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8));
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void openAiEmbeddingNonTransientError() {
+ when(openAiApi.embeddings(isA(EmbeddingRequest.class)))
+ .thenThrow(new RuntimeException("Non Transient Error"));
+ assertThrows(RuntimeException.class, () -> embeddingClient
+ .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)));
+ }
+
+ @Test
+ public void openAiAudioTranscriptionTransientError() {
+
+ var expectedResponse = new StructuredResponse("nl", 6.7f, "Transcription Text", List.of(), List.of());
+
+ when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedResponse)));
+
+ AudioTranscriptionResponse result = audioTranscriptionClient
+ .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac")));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput()).isEqualTo(expectedResponse.text());
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void openAiAudioTranscriptionNonTransientError() {
+ when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class)))
+ .thenThrow(new RuntimeException("Transient Error 1"));
+ assertThrows(RuntimeException.class, () -> audioTranscriptionClient
+ .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac"))));
+ }
+
+ @Test
+ public void openAiImageTransientError() {
+
+ var expectedResponse = new OpenAiImageResponse(678l, List.of(new Data("url678", "b64", "prompt")));
+
+ when(openAiImageApi.createImage(isA(OpenAiImageRequest.class)))
+ .thenThrow(new TransientAiException("Transient Error 1"))
+ .thenThrow(new TransientAiException("Transient Error 2"))
+ .thenReturn(ResponseEntity.of(Optional.of(expectedResponse)));
+
+ var result = imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))));
+
+ assertThat(result).isNotNull();
+ assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678");
+ assertThat(retryListener.onSuccessRetryCount).isEqualTo(2);
+ assertThat(retryListener.onErrorRetryCount).isEqualTo(2);
+ }
+
+ @Test
+ public void openAiImageNonTransientError() {
+ when(openAiImageApi.createImage(isA(OpenAiImageRequest.class)))
+ .thenThrow(new RuntimeException("Transient Error 1"));
+ assertThrows(RuntimeException.class,
+ () -> imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))));
+ }
+
+}
diff --git a/models/spring-ai-stability-ai/pom.xml b/models/spring-ai-stability-ai/pom.xml
index b14b51a17ec..b5180e6faf0 100644
--- a/models/spring-ai-stability-ai/pom.xml
+++ b/models/spring-ai-stability-ai/pom.xml
@@ -30,12 +30,11 @@
- org.springframework
- spring-web
- ${spring-framework.version}
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
-
org.springframework
diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java
index d39594998b5..e632e5235b5 100644
--- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java
+++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java
@@ -15,17 +15,24 @@
*/
package org.springframework.ai.stabilityai;
+import java.util.List;
+import java.util.stream.Collectors;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.springframework.ai.image.*;
+
+import org.springframework.ai.image.Image;
+import org.springframework.ai.image.ImageClient;
+import org.springframework.ai.image.ImageGeneration;
+import org.springframework.ai.image.ImageOptions;
+import org.springframework.ai.image.ImagePrompt;
+import org.springframework.ai.image.ImageResponse;
+import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.ai.stabilityai.api.StabilityAiImageOptions;
import org.springframework.util.Assert;
-import java.util.List;
-import java.util.stream.Collectors;
-
/**
* StabilityAiImageClient is a class that implements the ImageClient interface. It
* provides a client for calling the StabilityAI image generation API.
@@ -50,7 +57,7 @@ public StabilityAiImageClient(StabilityAiApi stabilityAiApi, StabilityAiImageOpt
}
public StabilityAiImageOptions getOptions() {
- return options;
+ return this.options;
}
/**
@@ -159,17 +166,4 @@ private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) {
return builder.build();
}
- private ImagePrompt createUpdatedPrompt(ImagePrompt prompt) {
- ImageOptions runtimeImageModelOptions = prompt.getOptions();
- ImageOptionsBuilder imageOptionsBuilder = ImageOptionsBuilder.builder();
-
- if (runtimeImageModelOptions != null) {
- if (runtimeImageModelOptions.getModel() != null) {
- imageOptionsBuilder.withModel(runtimeImageModelOptions.getModel());
- }
- }
- ImageOptions updatedImageModelOptions = imageOptionsBuilder.build();
- return new ImagePrompt(prompt.getInstructions(), updatedImageModelOptions);
- }
-
}
diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java
index 680f4ee3620..648dae1a5d9 100644
--- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java
+++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java
@@ -35,16 +35,17 @@ public StabilityAiImageGenerationMetadata(String finishReason, Long seed) {
}
public String getFinishReason() {
- return finishReason;
+ return this.finishReason;
}
public Long getSeed() {
- return seed;
+ return this.seed;
}
@Override
public String toString() {
- return "StabilityAiImageGenerationMetadata{" + "finishReason='" + finishReason + '\'' + ", seed=" + seed + '}';
+ return "StabilityAiImageGenerationMetadata{" + "finishReason='" + this.finishReason + '\'' + ", seed="
+ + this.seed + '}';
}
@Override
@@ -53,12 +54,12 @@ public boolean equals(Object o) {
return true;
if (!(o instanceof StabilityAiImageGenerationMetadata that))
return false;
- return Objects.equals(finishReason, that.finishReason) && Objects.equals(seed, that.seed);
+ return Objects.equals(this.finishReason, that.finishReason) && Objects.equals(this.seed, that.seed);
}
@Override
public int hashCode() {
- return Objects.hash(finishReason, seed);
+ return Objects.hash(this.finishReason, this.seed);
}
}
diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java
index 15c8645e481..e1d7c9efa5a 100644
--- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java
+++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java
@@ -20,11 +20,25 @@
*/
public enum StyleEnum {
- THREE_D_MODEL("3d-model"), ANALOG_FILM("analog-film"), ANIME("anime"), CINEMATIC("cinematic"),
- COMIC_BOOK("comic-book"), DIGITAL_ART("digital-art"), ENHANCE("enhance"), FANTASY_ART("fantasy-art"),
- ISOMETRIC("isometric"), LINE_ART("line-art"), LOW_POLY("low-poly"), MODELING_COMPOUND("modeling-compound"),
- NEON_PUNK("neon-punk"), ORIGAMI("origami"), PHOTOGRAPHIC("photographic"), PIXEL_ART("pixel-art"),
+ // @formatter:off
+ THREE_D_MODEL("3d-model"),
+ ANALOG_FILM("analog-film"),
+ ANIME("anime"),
+ CINEMATIC("cinematic"),
+ COMIC_BOOK("comic-book"),
+ DIGITAL_ART("digital-art"),
+ ENHANCE("enhance"),
+ FANTASY_ART("fantasy-art"),
+ ISOMETRIC("isometric"),
+ LINE_ART("line-art"),
+ LOW_POLY("low-poly"),
+ MODELING_COMPOUND("modeling-compound"),
+ NEON_PUNK("neon-punk"),
+ ORIGAMI("origami"),
+ PHOTOGRAPHIC("photographic"),
+ PIXEL_ART("pixel-art"),
TILE_TEXTURE("tile-texture");
+ // @formatter:on
private final String text;
diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java
index 984098a1c8a..2ee5b2f7f24 100644
--- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java
+++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java
@@ -15,20 +15,18 @@
*/
package org.springframework.ai.stabilityai.api;
+import java.util.List;
+import java.util.function.Consumer;
+
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
-import com.fasterxml.jackson.databind.ObjectMapper;
+
+import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
-import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
-import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
-import java.io.IOException;
-import java.util.List;
-import java.util.function.Consumer;
-
/**
* Represents the StabilityAI API.
*/
@@ -80,35 +78,12 @@ public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Bu
headers.setContentType(MediaType.APPLICATION_JSON);
};
- ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() {
- @Override
- public boolean hasError(ClientHttpResponse response) throws IOException {
- return response.getStatusCode().isError();
- }
-
- @Override
- public void handleError(ClientHttpResponse response) throws IOException {
- if (response.getStatusCode().isError()) {
- throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(),
- new ObjectMapper().readValue(response.getBody(), ResponseError.class)));
- }
- }
- };
-
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
- .defaultStatusHandler(responseErrorHandler)
+ .defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER)
.build();
}
- @JsonInclude(JsonInclude.Include.NON_NULL)
- public record ResponseError(@JsonProperty("id") String id, @JsonProperty("name") String name,
- @JsonProperty("message") String message
-
- ) {
-
- }
-
@JsonInclude(JsonInclude.Include.NON_NULL)
public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts,
@JsonProperty("height") Integer height, @JsonProperty("width") Integer width,
diff --git a/pom.xml b/pom.xml
index 57a62fd27e6..da31f6831f7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -59,6 +59,7 @@
vector-stores/spring-ai-qdrant
spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai
spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai
+ spring-ai-retry
@@ -115,6 +116,7 @@
1.17.0
26.33.0
1.7.1
+ 2.0.5
3.25.2
@@ -297,6 +299,8 @@ limitations under the License.
+ **/.antlr/**
+ **/aot.factories
**/.sdkmanrc
**/*.adoc
**/*.puml
diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml
index e838b6c080e..4f604f768e7 100644
--- a/spring-ai-bom/pom.xml
+++ b/spring-ai-bom/pom.xml
@@ -33,12 +33,18 @@
${project.version}
-
-
- org.springframework.ai
- spring-ai-pdf-document-reader
- ${project.version}
-
+
+ org.springframework.ai
+ spring-ai-retry
+ ${project.parent.version}
+
+
+
+
+ org.springframework.ai
+ spring-ai-pdf-document-reader
+ ${project.version}
+
org.springframework.ai
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc
index 3ac1ccd5b71..46f0decd7f6 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc
@@ -41,6 +41,19 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man
=== Image Generation Properties
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Chat client.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+|====
+
+
The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI.
[cols="3,5,1"]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc
index faf9d3b9720..e540bae43f7 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc
@@ -49,6 +49,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man
=== Chat Properties
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Chat client.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+|====
+
The prefix `spring.ai.mistralai` is used as the property prefix that lets you connect to OpenAI.
[cols="3,5,1"]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc
index 2d184853c39..55a767a7d1e 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc
@@ -49,6 +49,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man
=== Chat Properties
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Chat client.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+|====
+
The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI.
[cols="3,5,1"]
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
index 9a9d3c2e5b7..57f19e66b36 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc
@@ -50,6 +50,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man
=== Embedding Properties
+The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Chat client.
+
+[cols="3,5,1"]
+|====
+| Property | Description | Default
+
+| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10
+| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec.
+| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5
+| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min.
+|====
+
The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI.
[cols="3,5,1"]
diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml
new file mode 100644
index 00000000000..9c2327c9037
--- /dev/null
+++ b/spring-ai-retry/pom.xml
@@ -0,0 +1,50 @@
+
+
+ 4.0.0
+
+ org.springframework.ai
+ spring-ai
+ 0.8.1-SNAPSHOT
+
+ spring-ai-retry
+ jar
+ Spring AI Retry
+ Spring AI utility project helping with remote call retry
+ https://github.com/spring-projects/spring-ai
+
+
+ https://github.com/spring-projects/spring-ai
+ git://github.com/spring-projects/spring-ai.git
+ git@github.com:spring-projects/spring-ai.git
+
+
+
+
+
+
+ org.springframework.ai
+ spring-ai-core
+ ${project.parent.version}
+
+
+
+ org.springframework.retry
+ spring-retry
+ ${spring-retry.version}
+
+
+
+ org.springframework
+ spring-webflux
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+
diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java
new file mode 100644
index 00000000000..55dbe5e3d74
--- /dev/null
+++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.retry;
+
+/**
+ * Root of the hierarchy of Model access exceptions that are considered non-transient -
+ * where a retry of the same operation would fail unless the cause of the Exception is
+ * corrected.
+ *
+ * @author Christian Tzolov
+ */
+public class NonTransientAiException extends RuntimeException {
+
+ public NonTransientAiException(String message) {
+ super(message);
+ }
+
+ public NonTransientAiException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java
similarity index 53%
rename from models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java
rename to spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java
index 8219170532d..ec6b2d671fb 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java
+++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java
@@ -13,32 +13,43 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.springframework.ai.openai.api.common;
+package org.springframework.ai.retry;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
-import java.util.function.Consumer;
+import java.time.Duration;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import org.springframework.http.HttpHeaders;
-import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.lang.NonNull;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.ResponseErrorHandler;
/**
- * @author Christian Tzolov
+ *
*/
-public class ApiUtils {
+public class RetryUtils {
- public static final String DEFAULT_BASE_URL = "https://api.openai.com";
+ private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class);
- public static Consumer getJsonContentHeaders(String apiKey) {
- return (headers) -> {
- headers.setBearerAuth(apiKey);
- headers.setContentType(MediaType.APPLICATION_JSON);
- };
- };
+ public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder()
+ .maxAttempts(10)
+ .retryOn(TransientAiException.class)
+ .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
+ .withListener(new RetryListener() {
+ @Override
+ public void onError(RetryContext context,
+ RetryCallback callback, Throwable throwable) {
+ logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
+ };
+ })
+ .build();
public static final ResponseErrorHandler DEFAULT_RESPONSE_ERROR_HANDLER = new ResponseErrorHandler() {
@@ -52,10 +63,16 @@ public void handleError(@NonNull ClientHttpResponse response) throws IOException
if (response.getStatusCode().isError()) {
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
String message = String.format("%s - %s", response.getStatusCode().value(), error);
+ /**
+ * Thrown on 4xx client errors, such as 401 - Incorrect API key provided,
+ * 401 - You must be a member of an organization to use the API, 429 -
+ * Rate limit reached for requests, 429 - You exceeded your current quota
+ * , please check your plan and billing details.
+ */
if (response.getStatusCode().is4xxClientError()) {
- throw new OpenAiApiClientErrorException(message);
+ throw new NonTransientAiException(message);
}
- throw new OpenAiApiException(message);
+ throw new TransientAiException(message);
}
}
};
diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java
new file mode 100644
index 00000000000..4da1960255f
--- /dev/null
+++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.retry;
+
+/**
+ * Root of the hierarchy of Model access exceptions that are considered transient - where
+ * a previously failed operation might be able to succeed when the operation is retried
+ * without any intervention by application-level functionality.
+ *
+ * @author Christian Tzolov
+ */
+public class TransientAiException extends RuntimeException {
+
+ public TransientAiException(String message) {
+ super(message);
+ }
+
+ public TransientAiException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java
index a3f67f8864d..1328f3488ba 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java
@@ -17,6 +17,7 @@
import java.util.List;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.mistralai.MistralAiChatClient;
import org.springframework.ai.mistralai.MistralAiEmbeddingClient;
import org.springframework.ai.mistralai.api.MistralAiApi;
@@ -30,9 +31,11 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
/**
@@ -40,7 +43,7 @@
* @author Christian Tzolov
* @since 0.8.1
*/
-@AutoConfiguration(after = { RestClientAutoConfiguration.class })
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
@EnableConfigurationProperties({ MistralAiEmbeddingProperties.class, MistralAiCommonProperties.class,
MistralAiChatProperties.class })
@ConditionalOnClass(MistralAiApi.class)
@@ -51,13 +54,15 @@ public class MistralAiAutoConfiguration {
@ConditionalOnProperty(prefix = MistralAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public MistralAiEmbeddingClient mistralAiEmbeddingClient(MistralAiCommonProperties commonProperties,
- MistralAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder) {
+ MistralAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
var mistralAiApi = mistralAiApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(),
- embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder);
+ embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder,
+ responseErrorHandler);
return new MistralAiEmbeddingClient(mistralAiApi, embeddingProperties.getMetadataMode(),
- embeddingProperties.getOptions());
+ embeddingProperties.getOptions(), retryTemplate);
}
@Bean
@@ -66,20 +71,22 @@ public MistralAiEmbeddingClient mistralAiEmbeddingClient(MistralAiCommonProperti
matchIfMissing = true)
public MistralAiChatClient mistralAiChatClient(MistralAiCommonProperties commonProperties,
MistralAiChatProperties chatProperties, RestClient.Builder restClientBuilder,
- List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) {
+ List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(),
- chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder);
+ chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder, responseErrorHandler);
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
}
- return new MistralAiChatClient(mistralAiApi, chatProperties.getOptions(), functionCallbackContext);
+ return new MistralAiChatClient(mistralAiApi, chatProperties.getOptions(), functionCallbackContext,
+ retryTemplate);
}
private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl,
- RestClient.Builder restClientBuilder) {
+ RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
var resoledBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
@@ -87,7 +94,7 @@ private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String bas
Assert.hasText(resolvedApiKey, "Mistral API key must be set");
Assert.hasText(resoledBaseUrl, "Mistral base URL must be set");
- return new MistralAiApi(resoledBaseUrl, resolvedApiKey, restClientBuilder);
+ return new MistralAiApi(resoledBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
}
@Bean
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
index a6d2a56260e..b9899440641 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java
@@ -17,6 +17,7 @@
import java.util.List;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
@@ -35,15 +36,17 @@
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
+import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
+import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
/**
* @author Christian Tzolov
*/
-@AutoConfiguration(after = { RestClientAutoConfiguration.class })
+@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class })
@ConditionalOnClass(OpenAiApi.class)
@EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class,
OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiAudioTranscriptionProperties.class })
@@ -55,16 +58,17 @@ public class OpenAiAutoConfiguration {
matchIfMissing = true)
public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties,
OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder,
- List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) {
+ List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(),
- chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder);
+ chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) {
chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks);
}
- return new OpenAiChatClient(openAiApi, chatProperties.getOptions(), functionCallbackContext);
+ return new OpenAiChatClient(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate);
}
@Bean
@@ -72,17 +76,18 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper
@ConditionalOnProperty(prefix = OpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonProperties,
- OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder) {
+ OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder,
+ RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) {
var openAiApi = openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(),
- embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder);
+ embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler);
return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(),
- embeddingProperties.getOptions());
+ embeddingProperties.getOptions(), retryTemplate);
}
private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey,
- RestClient.Builder restClientBuilder) {
+ RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl;
Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set");
@@ -90,7 +95,7 @@ private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey,
String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey;
Assert.hasText(resolvedApiKey, "OpenAI API key must be set");
- return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder);
+ return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler);
}
@Bean
@@ -98,7 +103,9 @@ private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey,
@ConditionalOnProperty(prefix = OpenAiImageProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true",
matchIfMissing = true)
public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProperties,
- OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder) {
+ OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate,
+ ResponseErrorHandler responseErrorHandler) {
+
String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey()
: commonProperties.getApiKey();
@@ -108,15 +115,16 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp
Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
- var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder);
+ var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler);
- return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions());
+ return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate);
}
@Bean
@ConditionalOnMissingBean
public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConnectionProperties commonProperties,
- OpenAiAudioTranscriptionProperties transcriptionProperties) {
+ OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate,
+ ResponseErrorHandler responseErrorHandler) {
String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? transcriptionProperties.getApiKey()
: commonProperties.getApiKey();
@@ -127,10 +135,10 @@ public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConne
Assert.hasText(apiKey, "OpenAI API key must be set");
Assert.hasText(baseUrl, "OpenAI base URL must be set");
- var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder());
+ var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
OpenAiAudioTranscriptionClient openAiChatClient = new OpenAiAudioTranscriptionClient(openAiAudioApi,
- transcriptionProperties.getOptions());
+ transcriptionProperties.getOptions(), retryTemplate);
return openAiChatClient;
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java
new file mode 100644
index 00000000000..4daf371fcec
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.retry.NonTransientAiException;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.lang.NonNull;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StreamUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+
+/**
+ * @author Christian Tzolov
+ */
+@AutoConfiguration
+@ConditionalOnClass(RetryTemplate.class)
+@EnableConfigurationProperties({ SpringAiRetryProperties.class })
+public class SpringAiRetryAutoConfiguration {
+
+ private static final Logger logger = LoggerFactory.getLogger(SpringAiRetryAutoConfiguration.class);
+
+ @Bean
+ @ConditionalOnMissingBean
+ public RetryTemplate retryTemplate(SpringAiRetryProperties properties) {
+ return RetryTemplate.builder()
+ .maxAttempts(properties.getMaxAttempts())
+ .retryOn(NonTransientAiException.class)
+ .exponentialBackoff(properties.getBackoff().getInitialInterval(), properties.getBackoff().getMultiplier(),
+ properties.getBackoff().getMaxInterval())
+ .withListener(new RetryListener() {
+ @Override
+ public void onError(RetryContext context,
+ RetryCallback callback, Throwable throwable) {
+ logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
+ };
+ })
+ .build();
+ }
+
+ @Bean
+ @ConditionalOnMissingBean
+ public ResponseErrorHandler responseErrorHandler(SpringAiRetryProperties properties) {
+
+ return new ResponseErrorHandler() {
+
+ @Override
+ public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
+ return response.getStatusCode().isError();
+ }
+
+ @Override
+ public void handleError(@NonNull ClientHttpResponse response) throws IOException {
+ if (response.getStatusCode().isError()) {
+ String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
+ String message = String.format("%s - %s", response.getStatusCode().value(), error);
+ /**
+ * Thrown on 4xx client errors, such as 401 - Incorrect API key
+ * provided, 401 - You must be a member of an organization to use the
+ * API, 429 - Rate limit reached for requests, 429 - You exceeded your
+ * current quota , please check your plan and billing details.
+ */
+ if (properties.isNoRetryOnHttpClientErrors() && response.getStatusCode().is4xxClientError()) {
+ throw new NonTransientAiException(message);
+ }
+ // Explicitly configured non-transient codes
+ if (!CollectionUtils.isEmpty(properties.getNoRetryOnHttpCodes())
+ && properties.getNoRetryOnHttpCodes().contains(response.getStatusCode().value())) {
+ throw new NonTransientAiException(message);
+ }
+ throw new TransientAiException(message);
+ }
+ }
+ };
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java
new file mode 100644
index 00000000000..f560d0ebf58
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java
@@ -0,0 +1,122 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * @author Christian Tzolov
+ */
+@ConfigurationProperties(SpringAiRetryProperties.CONFIG_PREFIX)
+public class SpringAiRetryProperties {
+
+ public static final String CONFIG_PREFIX = "spring.ai.retry";
+
+ /**
+ * Maximum number of retry attempts.
+ */
+ private int maxAttempts = 10;
+
+ /**
+ * Exponential Backoff properties.
+ */
+ @NestedConfigurationProperty
+ private Backoff backoff = new Backoff();
+
+ private boolean noRetryOnHttpClientErrors = true;
+
+ private List noRetryOnHttpCodes = new ArrayList<>();
+
+ /**
+ * Exponential Backoff properties.
+ */
+ public static class Backoff {
+
+ /**
+ * Initial sleep duration.
+ */
+ private Duration initialInterval = Duration.ofMillis(2000);
+
+ /**
+ * Backoff interval multiplier.
+ */
+ private int multiplier = 5;
+
+ /**
+ * Maximum backoff duration.
+ */
+ private Duration maxInterval = Duration.ofMillis(3 * 60000);
+
+ public Duration getInitialInterval() {
+ return initialInterval;
+ }
+
+ public void setInitialInterval(Duration initialInterval) {
+ this.initialInterval = initialInterval;
+ }
+
+ public int getMultiplier() {
+ return multiplier;
+ }
+
+ public void setMultiplier(int multiplier) {
+ this.multiplier = multiplier;
+ }
+
+ public Duration getMaxInterval() {
+ return maxInterval;
+ }
+
+ public void setMaxInterval(Duration maxInterval) {
+ this.maxInterval = maxInterval;
+ }
+
+ }
+
+ public int getMaxAttempts() {
+ return this.maxAttempts;
+ }
+
+ public void setMaxAttempts(int maxAttempts) {
+ this.maxAttempts = maxAttempts;
+ }
+
+ public Backoff getBackoff() {
+ return this.backoff;
+ }
+
+ public List getNoRetryOnHttpCodes() {
+ return this.noRetryOnHttpCodes;
+ }
+
+ public void setNoRetryOnHttpCodes(List noRetryOnHttpCodes) {
+ this.noRetryOnHttpCodes = noRetryOnHttpCodes;
+ }
+
+ public boolean isNoRetryOnHttpClientErrors() {
+ return this.noRetryOnHttpClientErrors;
+ }
+
+ public void setNoRetryOnHttpClientErrors(boolean noRetryOnHttpClientErrors) {
+ this.noRetryOnHttpClientErrors = noRetryOnHttpClientErrors;
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
index f2ced6509de..172f4785d0b 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
+++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -22,4 +22,4 @@ org.springframework.ai.autoconfigure.vectorstore.azure.AzureVectorStoreAutoConfi
org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.neo4j.Neo4jVectorStoreAutoConfiguration
org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration
-
+org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java
index c3e56639c0d..25816e8c0e4 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java
@@ -24,6 +24,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import reactor.core.publisher.Flux;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -47,7 +48,8 @@ public class MistralAiAutoConfigurationIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
@Test
void generate() {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java
index 6963855c837..ee4cf9aa7ba 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java
@@ -16,6 +16,8 @@
package org.springframework.ai.autoconfigure.mistralai;
import org.junit.jupiter.api.Test;
+
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -33,8 +35,8 @@ public void embeddingProperties() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123",
"spring.ai.mistralai.embedding.options.model=MODEL_XYZ")
- .withConfiguration(
- AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class);
var connectionProperties = context.getBean(MistralAiCommonProperties.class);
@@ -55,8 +57,8 @@ public void embeddingOverrideConnectionProperties() {
new ApplicationContextRunner().withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL",
"spring.ai.mistralai.api-key=abc123", "spring.ai.mistralai.embedding.base-url=TEST_BASE_URL2",
"spring.ai.mistralai.embedding.api-key=456", "spring.ai.mistralai.embedding.options.model=MODEL_XYZ")
- .withConfiguration(
- AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class);
var connectionProperties = context.getBean(MistralAiCommonProperties.class);
@@ -79,8 +81,8 @@ public void embeddingOptionsTest() {
"spring.ai.mistralai.embedding.options.model=MODEL_XYZ",
"spring.ai.mistralai.embedding.options.encodingFormat=MyEncodingFormat")
- .withConfiguration(
- AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
.run(context -> {
var connectionProperties = context.getBean(MistralAiCommonProperties.class);
var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class);
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java
index 0b7972bdfbc..d4259f3ee38 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java
@@ -26,6 +26,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -48,7 +49,8 @@ class PaymentStatusBeanIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class))
.withUserConfiguration(Config.class);
@Test
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java
index 33a5e327c6c..15796f921b2 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java
@@ -26,6 +26,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -55,7 +56,8 @@ class PaymentStatusBeanOpenAiIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY"),
"spring.ai.openai.chat.base-url=https://api.mistral.ai")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.withUserConfiguration(Config.class);
@Test
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java
index d7a46ece4ee..58bd9e84cc6 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java
@@ -26,6 +26,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -46,7 +47,8 @@ public class PaymentStatusPromptIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) {
}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java
index 915c378f008..6bf052f5697 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java
@@ -29,6 +29,7 @@
import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration;
import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Request;
import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Response;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -54,7 +55,8 @@ public class WeatherServicePromptIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.mistralai.api-key=" + System.getenv("MISTRAL_AI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class));
@Test
void promptFunctionCall() {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
index 2561279f91f..41bbfb8c7b1 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java
@@ -31,6 +31,7 @@
import org.springframework.core.io.Resource;
import reactor.core.publisher.Flux;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.openai.OpenAiAudioTranscriptionClient;
@@ -49,7 +50,8 @@ public class OpenAiAutoConfigurationIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class));
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class));
@Test
void generate() {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
index 0eb67dd6c7b..e14a6a562de 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java
@@ -19,6 +19,7 @@
import org.skyscreamer.jsonassert.JSONAssert;
import org.skyscreamer.jsonassert.JSONCompareMode;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiEmbeddingClient;
import org.springframework.ai.openai.OpenAiImageClient;
@@ -52,7 +53,8 @@ public void chatProperties() {
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var chatProperties = context.getBean(OpenAiChatProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -78,7 +80,8 @@ public void transcriptionProperties() {
"spring.ai.openai.audio.transcription.options.model=MODEL_XYZ",
"spring.ai.openai.audio.transcription.options.temperature=0.55")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -106,7 +109,8 @@ public void chatOverrideConnectionProperties() {
"spring.ai.openai.chat.options.model=MODEL_XYZ",
"spring.ai.openai.chat.options.temperature=0.55")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var chatProperties = context.getBean(OpenAiChatProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -134,7 +138,8 @@ public void transcriptionOverrideConnectionProperties() {
"spring.ai.openai.audio.transcription.options.model=MODEL_XYZ",
"spring.ai.openai.audio.transcription.options.temperature=0.55")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -159,7 +164,8 @@ public void embeddingProperties() {
"spring.ai.openai.api-key=abc123",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -185,7 +191,8 @@ public void embeddingOverrideConnectionProperties() {
"spring.ai.openai.embedding.api-key=456",
"spring.ai.openai.embedding.options.model=MODEL_XYZ")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -209,7 +216,8 @@ public void imageProperties() {
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var imageProperties = context.getBean(OpenAiImageProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -236,7 +244,8 @@ public void imageOverrideConnectionProperties() {
"spring.ai.openai.image.options.model=MODEL_XYZ",
"spring.ai.openai.image.options.n=3")
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var imageProperties = context.getBean(OpenAiImageProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -304,7 +313,8 @@ public void chatOptionsTest() {
"spring.ai.openai.chat.options.user=userXYZ"
)
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var chatProperties = context.getBean(OpenAiChatProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -357,7 +367,8 @@ public void transcriptionOptionsTest() {
"spring.ai.openai.audio.transcription.options.temperature=0.55"
)
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -390,7 +401,8 @@ public void embeddingOptionsTest() {
"spring.ai.openai.embedding.options.user=userXYZ"
)
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class);
@@ -422,7 +434,8 @@ public void imageOptionsTest() {
"spring.ai.openai.image.options.user=userXYZ"
)
// @formatter:on
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
var imageProperties = context.getBean(OpenAiImageProperties.class);
var connectionProperties = context.getBean(OpenAiConnectionProperties.class);
@@ -448,7 +461,8 @@ void embeddingActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.embedding.enabled=false")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isEmpty();
@@ -456,7 +470,8 @@ void embeddingActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isNotEmpty();
@@ -465,7 +480,8 @@ void embeddingActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.embedding.enabled=true")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isNotEmpty();
@@ -477,7 +493,8 @@ void chatActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.chat.enabled=false")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiChatClient.class)).isEmpty();
@@ -485,7 +502,8 @@ void chatActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiChatClient.class)).isNotEmpty();
@@ -494,7 +512,8 @@ void chatActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.chat.enabled=true")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiChatClient.class)).isNotEmpty();
@@ -507,7 +526,8 @@ void imageActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.image.enabled=false")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiImageClient.class)).isEmpty();
@@ -515,7 +535,8 @@ void imageActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiImageClient.class)).isNotEmpty();
@@ -524,7 +545,8 @@ void imageActivation() {
new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL",
"spring.ai.openai.image.enabled=true")
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.run(context -> {
assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty();
assertThat(context.getBeansOfType(OpenAiImageClient.class)).isNotEmpty();
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java
index a5c0574b448..09e17abe5e5 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java
@@ -23,6 +23,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -42,7 +43,8 @@ public class FunctionCallbackInPromptIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class));
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class));
@Test
void functionCallTest() {
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java
index 9aa664925e2..51bd8721dd8 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java
@@ -24,6 +24,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -47,7 +48,8 @@ class FunctionCallbackWithPlainFunctionBeanIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.withUserConfiguration(Config.class);
@Test
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java
index df1f6a33276..9a4e70072dd 100644
--- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java
@@ -23,6 +23,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
+import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
@@ -45,7 +46,8 @@ public class FunctionCallbackWrapperIT {
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY"))
- .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
+ RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class))
.withUserConfiguration(Config.class);
@Test
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java
new file mode 100644
index 00000000000..f15e605fc50
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java
@@ -0,0 +1,44 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.web.client.ResponseErrorHandler;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * @author Christian Tzolov
+ */
+public class SpringAiRetryAutoConfigurationIT {
+
+ private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration(
+ AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class));
+
+ @Test
+ void testRetryAutoConfiguration() {
+ this.contextRunner.run((context) -> {
+ assertThat(context).hasSingleBean(RetryTemplate.class);
+ assertThat(context).hasSingleBean(ResponseErrorHandler.class);
+ });
+ }
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java
new file mode 100644
index 00000000000..ebf549276d7
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.boot.autoconfigure.AutoConfigurations;
+import org.springframework.boot.test.context.runner.ApplicationContextRunner;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Unit Tests for {@link SpringAiRetryProperties}.
+ *
+ * @author Christian Tzolov
+ */
+public class SpringAiRetryPropertiesTests {
+
+ @Test
+ public void retryDefaultProperties() {
+
+ new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class))
+ .run(context -> {
+ var retryProperties = context.getBean(SpringAiRetryProperties.class);
+
+ assertThat(retryProperties.getMaxAttempts()).isEqualTo(10);
+ assertThat(retryProperties.isNoRetryOnHttpClientErrors()).isTrue();
+ assertThat(retryProperties.getNoRetryOnHttpCodes()).isEmpty();
+ assertThat(retryProperties.getBackoff().getInitialInterval().toMillis()).isEqualTo(2000);
+ assertThat(retryProperties.getBackoff().getMultiplier()).isEqualTo(5);
+ assertThat(retryProperties.getBackoff().getMaxInterval().toMillis()).isEqualTo(3 * 60000);
+ });
+ }
+
+ @Test
+ public void retryCustomProperties() {
+
+ new ApplicationContextRunner().withPropertyValues(
+ // @formatter:off
+ "spring.ai.retry.max-attempts=100",
+ "spring.ai.retry.no-retry-on-http-client-errors=false",
+ "spring.ai.retry.no-retry-on-http-codes=404,500",
+ "spring.ai.retry.backoff.initial-interval=1000",
+ "spring.ai.retry.backoff.multiplier=2",
+ "spring.ai.retry.backoff.max-interval=60000" )
+ // @formatter:on
+ .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class))
+ .run(context -> {
+ var retryProperties = context.getBean(SpringAiRetryProperties.class);
+
+ assertThat(retryProperties.getMaxAttempts()).isEqualTo(100);
+ assertThat(retryProperties.isNoRetryOnHttpClientErrors()).isFalse();
+ assertThat(retryProperties.getNoRetryOnHttpCodes()).containsExactly(404, 500);
+ assertThat(retryProperties.getBackoff().getInitialInterval().toMillis()).isEqualTo(1000);
+ assertThat(retryProperties.getBackoff().getMultiplier()).isEqualTo(2);
+ assertThat(retryProperties.getBackoff().getMaxInterval().toMillis()).isEqualTo(60000);
+ });
+ }
+
+}