From c545dd2833bebbf79dc0e7c22eb6102764cfe644 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 6 Mar 2024 11:53:24 +0100 Subject: [PATCH] Refactor and centralize Retry logic: - Establish a new "spring-ai-retry" project, implementing a default HTTP error handler, RetryTemplate, and handling both Transient and Non-Transient Exceptions. - Streamline existing clients (e.g., OpenAI and MistralAI) to utilize "spring-ai-retry." - Integrate retry auto-configuration with customizable properties, extending it to OpenAI and MistralAI Auto-Configs. - Allow configuration of RetryTemplate and ResponseErrorHandler for various clients, including OpenAiChatClient, OpenAiEmbeddingClient, OpenAiAudioTranscriptionClient, OpenAiImageClient, MistralAiChatClient, and MistralAiEmbeddingClient. - Add tests for default RetryTemplate and ResponseErrorHandler configurations in OpenAI and MistralAI. - Introduce new retry auto-config properties: "noRetryOnHttpClientErrors" and "noRetryOnHttpCodes." - Implement tests for retry auto-config properties. - Generate missing license headers. 
--- models/spring-ai-mistral-ai/pom.xml | 21 +- .../ai/mistralai/MistralAiChatClient.java | 57 ++-- .../mistralai/MistralAiEmbeddingClient.java | 39 ++- .../ai/mistralai/api/MistralAiApi.java | 67 +---- .../ai/mistralai/RetryTests.java | 192 ++++++++++++ models/spring-ai-openai/pom.xml | 11 +- .../OpenAiAudioTranscriptionClient.java | 18 +- .../ai/openai/OpenAiChatClient.java | 31 +- .../ai/openai/OpenAiEmbeddingClient.java | 26 +- .../ai/openai/OpenAiImageClient.java | 34 +-- .../ai/openai/api/ApiUtils.java | 37 +++ .../ai/openai/api/OpenAiApi.java | 30 +- .../ai/openai/api/OpenAiAudioApi.java | 11 +- .../ai/openai/api/OpenAiImageApi.java | 23 +- .../ai/openai/ChatCompletionRequestTests.java | 2 +- .../ai/openai/{chat => }/api/OpenAiApiIT.java | 2 +- .../api/RestClientBuilderTests.java | 3 +- .../api/tool/MockWeatherService.java | 2 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 3 +- ...ithTranscriptionResponseMetadataTests.java | 3 +- .../ai/openai/chat/OpenAiChatClientIT.java | 2 +- .../ai/openai/chat/RetryTests.java | 273 ++++++++++++++++++ models/spring-ai-stability-ai/pom.xml | 7 +- .../stabilityai/StabilityAiImageClient.java | 30 +- .../StabilityAiImageGenerationMetadata.java | 11 +- .../ai/stabilityai/StyleEnum.java | 22 +- .../ai/stabilityai/api/StabilityAiApi.java | 37 +-- pom.xml | 4 + spring-ai-bom/pom.xml | 18 +- .../pages/api/clients/image/openai-image.adoc | 13 + .../pages/api/clients/mistralai-chat.adoc | 12 + .../ROOT/pages/api/clients/openai-chat.adoc | 12 + .../api/embeddings/openai-embeddings.adoc | 12 + spring-ai-retry/pom.xml | 50 ++++ .../ai/retry/NonTransientAiException.java | 35 +++ .../springframework/ai/retry/RetryUtils.java | 47 ++- .../ai/retry/TransientAiException.java | 35 +++ .../mistralai/MistralAiAutoConfiguration.java | 25 +- .../openai/OpenAiAutoConfiguration.java | 38 ++- .../retry/SpringAiRetryAutoConfiguration.java | 105 +++++++ .../retry/SpringAiRetryProperties.java | 122 ++++++++ 
...ot.autoconfigure.AutoConfiguration.imports | 2 +- .../MistralAiAutoConfigurationIT.java | 4 +- .../mistralai/MistralAiPropertiesTests.java | 14 +- .../mistralai/tool/PaymentStatusBeanIT.java | 4 +- .../tool/PaymentStatusBeanOpenAiIT.java | 4 +- .../mistralai/tool/PaymentStatusPromptIT.java | 4 +- .../tool/WeatherServicePromptIT.java | 4 +- .../openai/OpenAiAutoConfigurationIT.java | 4 +- .../openai/OpenAiPropertiesTests.java | 64 ++-- .../tool/FunctionCallbackInPromptIT.java | 4 +- ...nctionCallbackWithPlainFunctionBeanIT.java | 4 +- .../tool/FunctionCallbackWrapperIT.java | 4 +- .../SpringAiRetryAutoConfigurationIT.java | 44 +++ .../retry/SpringAiRetryPropertiesTests.java | 73 +++++ 55 files changed, 1371 insertions(+), 384 deletions(-) create mode 100644 models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java create mode 100644 models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/{chat => }/api/OpenAiApiIT.java (98%) rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/{chat => }/api/RestClientBuilderTests.java (96%) rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/{chat => }/api/tool/MockWeatherService.java (97%) rename models/spring-ai-openai/src/test/java/org/springframework/ai/openai/{chat => }/api/tool/OpenAiApiToolFunctionCallIT.java (99%) create mode 100644 models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java create mode 100644 spring-ai-retry/pom.xml create mode 100644 spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java rename models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java => spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java (53%) create mode 100644 
spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java diff --git a/models/spring-ai-mistral-ai/pom.xml b/models/spring-ai-mistral-ai/pom.xml index a613e4f6a1c..07fd0e087ef 100644 --- a/models/spring-ai-mistral-ai/pom.xml +++ b/models/spring-ai-mistral-ai/pom.xml @@ -29,24 +29,13 @@ ${project.parent.version} - - org.springframework - spring-web - ${spring-framework.version} - - - - org.springframework.retry - spring-retry - 2.0.4 - - + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + - - org.springframework - spring-webflux - org.springframework spring-context-support diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java index 73bffce3e49..ee212817d35 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.mistralai; -import java.time.Duration; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -41,10 +40,8 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; import 
org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -70,17 +67,7 @@ public class MistralAiChatClient extends */ private final MistralAiApi mistralAiApi; - private final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(MistralAiApi.MistralAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); + private final RetryTemplate retryTemplate; public MistralAiChatClient(MistralAiApi mistralAiApi) { this(mistralAiApi, @@ -93,46 +80,50 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) { } public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) { - this(mistralAiApi, options, null); + this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); } public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options, - FunctionCallbackContext functionCallbackContext) { + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) { super(functionCallbackContext); Assert.notNull(mistralAiApi, "MistralAiApi must not be null"); Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.mistralAiApi = mistralAiApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override public ChatResponse call(Prompt prompt) { - // 
return retryTemplate.execute(ctx -> { var request = createRequest(prompt, false); - // var completionEntity = this.mistralAiApi.chatCompletionEntity(request); - ResponseEntity completionEntity = this.callWithFunctionSupport(request); + return retryTemplate.execute(ctx -> { - var chatCompletion = completionEntity.getBody(); - if (chatCompletion == null) { - log.warn("No chat completion returned for prompt: {}", prompt); - return new ChatResponse(List.of()); - } + ResponseEntity completionEntity = this.callWithFunctionSupport(request); - List generations = chatCompletion.choices() - .stream() - .map(choice -> new Generation(choice.message().content(), Map.of("role", choice.message().role().name())) - .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null))) - .toList(); + var chatCompletion = completionEntity.getBody(); + if (chatCompletion == null) { + log.warn("No chat completion returned for prompt: {}", prompt); + return new ChatResponse(List.of()); + } + + List generations = chatCompletion.choices() + .stream() + .map(choice -> new Generation(choice.message().content(), + Map.of("role", choice.message().role().name())) + .withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null))) + .toList(); - return new ChatResponse(generations); - // }); + return new ChatResponse(generations); + }); } @Override public Flux stream(Prompt prompt) { + var request = createRequest(prompt, true); + return retryTemplate.execute(ctx -> { - var request = createRequest(prompt, true); var completionChunks = this.mistralAiApi.chatCompletionStream(request); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java index 30e17b36f38..e42908c348d 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java +++ 
b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingClient.java @@ -15,23 +15,25 @@ */ package org.springframework.ai.mistralai; +import java.util.List; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.document.Document; import org.springframework.ai.document.MetadataMode; -import org.springframework.ai.embedding.*; +import org.springframework.ai.embedding.AbstractEmbeddingClient; +import org.springframework.ai.embedding.Embedding; +import org.springframework.ai.embedding.EmbeddingOptions; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; import org.springframework.ai.mistralai.api.MistralAiApi; -import org.springframework.ai.mistralai.api.MistralAiApi.MistralAiApiException; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; +import org.springframework.ai.retry.RetryUtils; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; -import java.time.Duration; -import java.util.List; - /** * @author Ricken Bazolo * @since 0.8.1 @@ -46,17 +48,7 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient { private final MistralAiApi mistralAiApi; - private final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(MistralAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - log.warn("Retry error. 
Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); + private final RetryTemplate retryTemplate; public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) { this(mistralAiApi, MetadataMode.EMBED); @@ -64,22 +56,25 @@ public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) { public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) { this(mistralAiApi, metadataMode, - MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build()); + MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); } public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) { - this(mistralAiApi, MetadataMode.EMBED, options); + this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE); } public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode, - MistralAiEmbeddingOptions options) { + MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) { Assert.notNull(mistralAiApi, "MistralAiApi must not be null"); Assert.notNull(metadataMode, "metadataMode must not be null"); Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.mistralAiApi = mistralAiApi; this.metadataMode = metadataMode; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java index f227e5f521a..d2bd75c4a7c 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/api/MistralAiApi.java @@ -15,8 +15,6 @@ */ package 
org.springframework.ai.mistralai.api; -import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -25,22 +23,18 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.http.client.ClientHttpResponse; -import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.StreamUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -70,8 +64,6 @@ public class MistralAiApi { private WebClient webClient; - private final ObjectMapper objectMapper; - /** * Create a new client api with DEFAULT_BASE_URL * @param mistralAiApiKey Mistral api Key. @@ -86,7 +78,7 @@ public MistralAiApi(String mistralAiApiKey) { * @param mistralAiApiKey Mistral api Key. 
*/ public MistralAiApi(String baseUrl, String mistralAiApiKey) { - this(baseUrl, mistralAiApiKey, RestClient.builder()); + this(baseUrl, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } /** @@ -94,67 +86,22 @@ public MistralAiApi(String baseUrl, String mistralAiApiKey) { * @param baseUrl api base URL. * @param mistralAiApiKey Mistral api Key. * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. */ - public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder) { - - this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { Consumer jsonContentHeaders = headers -> { headers.setBearerAuth(mistralAiApiKey); headers.setContentType(MediaType.APPLICATION_JSON); }; - var responseErrorHandler = new ResponseErrorHandler() { - - @Override - public boolean hasError(@NonNull ClientHttpResponse response) throws IOException { - return response.getStatusCode().isError(); - } - - @Override - public void handleError(@NonNull ClientHttpResponse response) throws IOException { - if (response.getStatusCode().isError()) { - String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8); - String message = String.format("%s - %s", response.getStatusCode().value(), error); - if (response.getStatusCode().is4xxClientError()) { - throw new MistralAiApiClientErrorException(message); - } - throw new MistralAiApiException(message); - } - } - }; - this.restClient = restClientBuilder.baseUrl(baseUrl) .defaultHeaders(jsonContentHeaders) .defaultStatusHandler(responseErrorHandler) .build(); - this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); - } - - public static class MistralAiApiException extends RuntimeException { 
- - public MistralAiApiException(String message) { - super(message); - } - - public MistralAiApiException(String message, Throwable t) { - super(message, t); - } - - } - - /** - * Thrown on 4xx client errors, such as 401 - Incorrect API key provided, 401 - You - * must be a member of an organization to use the API, 429 - Rate limit reached for - * requests, 429 - You exceeded your current quota , please check your plan and - * billing details. - */ - public static class MistralAiApiClientErrorException extends RuntimeException { - - public MistralAiApiClientErrorException(String message) { - super(message); - } + this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build(); } /** @@ -594,7 +541,7 @@ public enum ChatCompletionFinishReason { // anticipation of future changes. Based on: // https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py @JsonProperty("error") ERROR, - + @JsonProperty("tool_calls") TOOL_CALLS // @formatter:on diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java new file mode 100644 index 00000000000..4c6a93a900e --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/RetryTests.java @@ -0,0 +1,192 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.mistralai; + +import java.util.List; +import java.util.Optional; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.mistralai.api.MistralAiApi; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletion; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionChunk; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionFinishReason; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; +import org.springframework.ai.mistralai.api.MistralAiApi.Embedding; +import org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingList; +import org.springframework.ai.mistralai.api.MistralAiApi.EmbeddingRequest; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * @author Christian Tzolov + */ +@SuppressWarnings("unchecked") 
+@ExtendWith(MockitoExtension.class) +public class RetryTests { + + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + private @Mock MistralAiApi mistralAiApi; + + private MistralAiChatClient chatClient; + + private MistralAiEmbeddingClient embeddingClient; + + @BeforeEach + public void beforeEach() { + retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatClient = new MistralAiChatClient(mistralAiApi, + MistralAiChatOptions.builder() + .withTemperature(0.7f) + .withTopP(1f) + .withSafePrompt(false) + .withModel(MistralAiApi.ChatModel.TINY.getValue()) + .build(), + null, retryTemplate); + embeddingClient = new MistralAiEmbeddingClient(mistralAiApi, MetadataMode.EMBED, + MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(), + retryTemplate); + } + + @Test + public void mistralAiChatTransientError() { + + var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), + ChatCompletionFinishReason.STOP); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789l, "model", + List.of(choice), new MistralAiApi.Usage(10, 10, 10)); + + when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + 
var result = chatClient.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void mistralAiChatNonTransientError() { + when(mistralAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + } + + @Test + public void mistralAiChatStreamTransientError() { + + var choice = new ChatCompletionChunk.ChunkChoice(0, new ChatCompletionMessage("Response", Role.ASSISTANT), + ChatCompletionFinishReason.STOP); + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", "chat.completion.chunk", 789l, + "model", List.of(choice)); + + when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(Flux.just(expectedChatCompletion)); + + var result = chatClient.stream(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isSameAs("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void mistralAiChatStreamNonTransientError() { + when(mistralAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text"))); + } + + @Test + public void mistralAiEmbeddingTransientError() { + + EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", + List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new MistralAiApi.Usage(10, 10, 
10)); + + when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8)); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void mistralAiEmbeddingNonTransientError() { + when(mistralAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + +} diff --git a/models/spring-ai-openai/pom.xml b/models/spring-ai-openai/pom.xml index 7fbfce8ea56..10be9e4c14f 100644 --- a/models/spring-ai-openai/pom.xml +++ b/models/spring-ai-openai/pom.xml @@ -29,9 +29,9 @@ - org.springframework.retry - spring-retry - 2.0.4 + org.springframework.ai + spring-ai-retry + ${project.parent.version} @@ -57,11 +57,6 @@ ${victools.version} - - org.springframework - spring-webflux - - org.springframework spring-context-support diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java index 68e24808de4..30982ec12af 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioTranscriptionClient.java @@ -31,8 +31,6 @@ package org.springframework.ai.openai; -import 
java.time.Duration; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,12 +38,12 @@ import org.springframework.ai.model.ModelClient; import org.springframework.ai.openai.api.OpenAiAudioApi; import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; -import org.springframework.ai.openai.api.common.OpenAiApiException; import org.springframework.ai.openai.audio.transcription.AudioTranscription; import org.springframework.ai.openai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.openai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; +import org.springframework.ai.retry.RetryUtils; import org.springframework.core.io.Resource; import org.springframework.http.ResponseEntity; import org.springframework.retry.support.RetryTemplate; @@ -66,11 +64,7 @@ public class OpenAiAudioTranscriptionClient private final OpenAiAudioTranscriptionOptions defaultOptions; - public final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(OpenAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .build(); + public final RetryTemplate retryTemplate; private final OpenAiAudioApi audioApi; @@ -80,14 +74,18 @@ public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi) { .withModel(OpenAiAudioApi.WhisperModel.WHISPER_1.getValue()) .withResponseFormat(OpenAiAudioApi.TranscriptResponseFormat.JSON) .withTemperature(0.7f) - .build()); + .build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); } - public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options) { + public OpenAiAudioTranscriptionClient(OpenAiAudioApi audioApi, OpenAiAudioTranscriptionOptions options, + RetryTemplate retryTemplate) { Assert.notNull(audioApi, "OpenAiAudioApi must not be 
null"); Assert.notNull(options, "OpenAiTranscriptionOptions must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.audioApi = audioApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } public String call(Resource audioResource) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index bd977a58d05..9e83ce6b07d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.openai; -import java.time.Duration; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -43,14 +42,11 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall; -import org.springframework.ai.openai.api.common.OpenAiApiException; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.openai.metadata.OpenAiChatResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor; +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -73,7 +69,7 @@ public class OpenAiChatClient extends AbstractFunctionCallSupport> implements ChatClient, StreamingChatClient { - private final Logger logger = 
LoggerFactory.getLogger(getClass()); + private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClient.class); /** * The default options used for the chat completion requests. @@ -83,18 +79,7 @@ public class OpenAiChatClient extends /** * The retry template used to retry the OpenAI API calls. */ - public final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(OpenAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - @Override - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); + public final RetryTemplate retryTemplate; /** * Low-level access to the OpenAI API. @@ -107,24 +92,26 @@ public OpenAiChatClient(OpenAiApi openAiApi) { } public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options) { - this(openAiApi, options, null); + this(openAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE); } public OpenAiChatClient(OpenAiApi openAiApi, OpenAiChatOptions options, - FunctionCallbackContext functionCallbackContext) { + FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) { super(functionCallbackContext); Assert.notNull(openAiApi, "OpenAiApi must not be null"); Assert.notNull(options, "Options must not be null"); + Assert.notNull(retryTemplate, "RetryTemplate must not be null"); this.openAiApi = openAiApi; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override public ChatResponse call(Prompt prompt) { - return this.retryTemplate.execute(ctx -> { + ChatCompletionRequest request = createRequest(prompt, false); - ChatCompletionRequest request = createRequest(prompt, false); + return this.retryTemplate.execute(ctx -> { ResponseEntity completionEntity = this.callWithFunctionSupport(request); diff --git 
a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java index 839828f9541..f5d5730d750 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingClient.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.openai; -import java.time.Duration; import java.util.List; import org.slf4j.Logger; @@ -33,10 +32,7 @@ import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.ai.openai.api.OpenAiApi.Usage; -import org.springframework.ai.openai.api.common.OpenAiApiException; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; +import org.springframework.ai.retry.RetryUtils; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -51,17 +47,7 @@ public class OpenAiEmbeddingClient extends AbstractEmbeddingClient { private final OpenAiEmbeddingOptions defaultOptions; - private final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(OpenAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. 
Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); + private final RetryTemplate retryTemplate; private final OpenAiApi openAiApi; @@ -73,17 +59,21 @@ public OpenAiEmbeddingClient(OpenAiApi openAiApi) { public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode) { this(openAiApi, metadataMode, - OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build()); + OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(), + RetryUtils.DEFAULT_RETRY_TEMPLATE); } - public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options) { + public OpenAiEmbeddingClient(OpenAiApi openAiApi, MetadataMode metadataMode, OpenAiEmbeddingOptions options, + RetryTemplate retryTemplate) { Assert.notNull(openAiApi, "OpenAiService must not be null"); Assert.notNull(metadataMode, "metadataMode must not be null"); Assert.notNull(options, "options must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.openAiApi = openAiApi; this.metadataMode = metadataMode; this.defaultOptions = options; + this.retryTemplate = retryTemplate; } @Override diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java index e9020302e38..1bc78d4c23d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageClient.java @@ -15,7 +15,6 @@ */ package org.springframework.ai.openai; -import java.time.Duration; import java.util.List; import org.slf4j.Logger; @@ -30,13 +29,10 @@ import org.springframework.ai.image.ImageResponseMetadata; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiImageApi; -import 
org.springframework.ai.openai.api.common.OpenAiApiException; import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.openai.metadata.OpenAiImageResponseMetadata; +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; -import org.springframework.retry.RetryCallback; -import org.springframework.retry.RetryContext; -import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -50,38 +46,32 @@ */ public class OpenAiImageClient implements ImageClient { - private final Logger logger = LoggerFactory.getLogger(getClass()); + private final static Logger logger = LoggerFactory.getLogger(OpenAiImageClient.class); private OpenAiImageOptions defaultOptions; private final OpenAiImageApi openAiImageApi; - public final RetryTemplate retryTemplate = RetryTemplate.builder() - .maxAttempts(10) - .retryOn(OpenAiApiException.class) - .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) - .withListener(new RetryListener() { - public void onError(RetryContext context, - RetryCallback callback, Throwable throwable) { - logger.warn("Retry error. 
Retry count:" + context.getRetryCount(), throwable); - }; - }) - .build(); + public final RetryTemplate retryTemplate; public OpenAiImageClient(OpenAiImageApi openAiImageApi) { + this(openAiImageApi, OpenAiImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE); + } + + public OpenAiImageClient(OpenAiImageApi openAiImageApi, OpenAiImageOptions defaultOptions, + RetryTemplate retryTemplate) { Assert.notNull(openAiImageApi, "OpenAiImageApi must not be null"); + Assert.notNull(defaultOptions, "defaultOptions must not be null"); + Assert.notNull(retryTemplate, "retryTemplate must not be null"); this.openAiImageApi = openAiImageApi; + this.defaultOptions = defaultOptions; + this.retryTemplate = retryTemplate; } public OpenAiImageOptions getDefaultOptions() { return this.defaultOptions; } - public OpenAiImageClient withDefaultOptions(OpenAiImageOptions defaultOptions) { - this.defaultOptions = defaultOptions; - return this; - } - @Override public ImageResponse call(ImagePrompt imagePrompt) { return this.retryTemplate.execute(ctx -> { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java new file mode 100644 index 00000000000..36ad4b8f758 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ApiUtils.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.openai.api; + +import java.util.function.Consumer; + +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; + +/** + * @author Christian Tzolov + */ +public class ApiUtils { + + public static final String DEFAULT_BASE_URL = "https://api.openai.com"; + + public static Consumer getJsonContentHeaders(String apiKey) { + return (headers) -> { + headers.setBearerAuth(apiKey); + headers.setContentType(MediaType.APPLICATION_JSON); + }; + }; + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 647f06c8d1b..795014139cd 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -26,14 +26,13 @@ import reactor.core.publisher.Mono; import org.springframework.ai.model.ModelOptionsUtils; -import org.springframework.ai.openai.api.common.ApiUtils; +import org.springframework.ai.retry.RetryUtils; import org.springframework.boot.context.properties.bind.ConstructorBinding; import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -53,8 +52,6 @@ public class OpenAiApi { private final RestClient restClient; - private final RestClient multipartRestClient; - private final WebClient webClient; /** @@ -84,20 +81,23 @@ public 
OpenAiApi(String baseUrl, String openAiToken) { * @param restClientBuilder RestClient builder. */ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { + this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new chat completion api. + * + * @param baseUrl api base URL. + * @param openAiToken OpenAI apiKey. + * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. + */ + public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { this.restClient = restClientBuilder .baseUrl(baseUrl) .defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken)) - .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER) - .build(); - - this.multipartRestClient = restClientBuilder - .baseUrl(baseUrl) - .defaultHeaders(multipartFormDataHeaders -> { - multipartFormDataHeaders.setBearerAuth(openAiToken); - multipartFormDataHeaders.setContentType(MediaType.MULTIPART_FORM_DATA); - }) - .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER) + .defaultStatusHandler(responseErrorHandler) .build(); this.webClient = WebClient.builder() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java index fcdd1fd4d13..65e5ca310af 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java @@ -21,12 +21,13 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.openai.api.common.ApiUtils; +import org.springframework.ai.retry.RetryUtils; import
org.springframework.core.io.ByteArrayResource; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; /** @@ -45,7 +46,7 @@ public class OpenAiAudioApi { * @param openAiToken OpenAI apiKey. */ public OpenAiAudioApi(String openAiToken) { - this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder()); + this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } /** @@ -53,12 +54,14 @@ public OpenAiAudioApi(String openAiToken) { * @param baseUrl api base URL. * @param openAiToken OpenAI apiKey. * @param restClientBuilder RestClient builder. + * @param responseErrorHandler Response error handler. */ - public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { + public OpenAiAudioApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(headers -> { headers.setBearerAuth(openAiToken); - }).defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER).build(); + }).defaultStatusHandler(responseErrorHandler).build(); } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java index 6499f51174f..18b3d0bff4e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java @@ -20,9 +20,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import 
org.springframework.ai.openai.api.common.ApiUtils; +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; +import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; /** @@ -44,11 +45,29 @@ public OpenAiImageApi(String openAiToken) { this(ApiUtils.DEFAULT_BASE_URL, openAiToken, RestClient.builder()); } + /** + * Create a new OpenAI Image API with the provided base URL. + * @param baseUrl the base URL for the OpenAI API. + * @param openAiToken OpenAI apiKey. + * @param restClientBuilder the rest client builder to use. + */ public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { + this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); + } + + /** + * Create a new OpenAI Image API with the provided base URL. + * @param baseUrl the base URL for the OpenAI API. + * @param openAiToken OpenAI apiKey. + * @param restClientBuilder the rest client builder to use. + * @param responseErrorHandler the response error handler to use. 
+ */ + public OpenAiImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder, + ResponseErrorHandler responseErrorHandler) { this.restClient = restClientBuilder.baseUrl(baseUrl) .defaultHeaders(ApiUtils.getJsonContentHeaders(openAiToken)) - .defaultStatusHandler(ApiUtils.DEFAULT_RESPONSE_ERROR_HANDLER) + .defaultStatusHandler(responseErrorHandler) .build(); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java index a3190312d7f..300cdfafbb8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/ChatCompletionRequestTests.java @@ -22,7 +22,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.chat.api.tool.MockWeatherService; +import org.springframework.ai.openai.api.tool.MockWeatherService; import static org.assertj.core.api.Assertions.assertThat; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java similarity index 98% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index e54b7e87792..182a11dcec2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations 
under the License. */ -package org.springframework.ai.openai.chat.api; +package org.springframework.ai.openai.api; import java.util.List; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java similarity index 96% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java index 96c45d56bf6..ffdccf035fb 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/RestClientBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/RestClientBuilderTests.java @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.api; +package org.springframework.ai.openai.api; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.http.client.SimpleClientHttpRequestFactory; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java similarity index 97% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java index e1364756ac3..db41af1f0d4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/MockWeatherService.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/MockWeatherService.java @@ -13,7 +13,7 @@ * See the License for the 
specific language governing permissions and * limitations under the License. */ -package org.springframework.ai.openai.chat.api.tool; +package org.springframework.ai.openai.api.tool; import java.util.function.Function; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java similarity index 99% rename from models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java rename to models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index b3ac24b09d9..4138a24a98c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -13,7 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.openai.chat.api.tool; + +package org.springframework.ai.openai.api.tool; import java.util.ArrayList; import java.util.List; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java index 8d99b60fe13..3b2629714f8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientWithTranscriptionResponseMetadataTests.java @@ -26,6 +26,7 @@ import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionMetadata; import org.springframework.ai.openai.metadata.audio.OpenAiAudioTranscriptionResponseMetadata; import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders; +import org.springframework.ai.retry.RetryUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.autoconfigure.web.client.RestClientTest; @@ -151,7 +152,7 @@ static class Config { @Bean public OpenAiAudioApi chatCompletionApi(RestClient.Builder builder) { - return new OpenAiAudioApi("", TEST_API_KEY, builder); + return new OpenAiAudioApi("", TEST_API_KEY, builder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER); } @Bean diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java index a5f754a2ce7..579502daff1 100644 --- 
a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java @@ -37,7 +37,7 @@ import org.springframework.ai.model.function.FunctionCallbackWrapper; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.OpenAiTestConfiguration; -import org.springframework.ai.openai.chat.api.tool.MockWeatherService; +import org.springframework.ai.openai.api.tool.MockWeatherService; import org.springframework.ai.openai.testutils.AbstractIT; import org.springframework.ai.parser.BeanOutputParser; import org.springframework.ai.parser.ListOutputParser; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java new file mode 100644 index 00000000000..4abe2625c3d --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/RetryTests.java @@ -0,0 +1,273 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.openai.chat; + +import java.util.List; +import java.util.Optional; + +import groovyjarjarpicocli.CommandLine.TraceLevel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.image.ImageMessage; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.openai.OpenAiAudioTranscriptionClient; +import org.springframework.ai.openai.OpenAiAudioTranscriptionOptions; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiEmbeddingClient; +import org.springframework.ai.openai.OpenAiEmbeddingOptions; +import org.springframework.ai.openai.OpenAiImageClient; +import org.springframework.ai.openai.OpenAiImageOptions; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; +import org.springframework.ai.openai.api.OpenAiApi.Embedding; +import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; +import org.springframework.ai.openai.api.OpenAiApi.EmbeddingRequest; +import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.api.OpenAiAudioApi.StructuredResponse; +import 
org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptResponseFormat; +import org.springframework.ai.openai.api.OpenAiAudioApi.TranscriptionRequest; +import org.springframework.ai.openai.api.OpenAiImageApi; +import org.springframework.ai.openai.api.OpenAiImageApi.Data; +import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageRequest; +import org.springframework.ai.openai.api.OpenAiImageApi.OpenAiImageResponse; +import org.springframework.ai.openai.audio.transcription.AudioTranscriptionPrompt; +import org.springframework.ai.openai.audio.transcription.AudioTranscriptionResponse; +import org.springframework.ai.retry.RetryUtils; +import org.springframework.ai.retry.TransientAiException; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.ResponseEntity; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.when; + +/** + * @author Christian Tzolov + */ +@SuppressWarnings("unchecked") +@ExtendWith(MockitoExtension.class) +public class RetryTests { + + private class TestRetryListener implements RetryListener { + + int onErrorRetryCount = 0; + + int onSuccessRetryCount = 0; + + @Override + public void onSuccess(RetryContext context, RetryCallback callback, T result) { + onSuccessRetryCount = context.getRetryCount(); + } + + @Override + public void onError(RetryContext context, RetryCallback callback, + Throwable throwable) { + onErrorRetryCount = context.getRetryCount(); + } + + } + + private TestRetryListener retryListener; + + private RetryTemplate retryTemplate; + + private @Mock OpenAiApi openAiApi; + + private @Mock OpenAiAudioApi openAiAudioApi; + + private @Mock 
OpenAiImageApi openAiImageApi; + + private OpenAiChatClient chatClient; + + private OpenAiEmbeddingClient embeddingClient; + + private OpenAiAudioTranscriptionClient audioTranscriptionClient; + + private OpenAiImageClient imageClient; + + @BeforeEach + public void beforeEach() { + retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE; + retryListener = new TestRetryListener(); + retryTemplate.registerListener(retryListener); + + chatClient = new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder().build(), null, retryTemplate); + embeddingClient = new OpenAiEmbeddingClient(openAiApi, MetadataMode.EMBED, + OpenAiEmbeddingOptions.builder().build(), retryTemplate); + audioTranscriptionClient = new OpenAiAudioTranscriptionClient(openAiAudioApi, + OpenAiAudioTranscriptionOptions.builder() + .withModel("model") + .withResponseFormat(TranscriptResponseFormat.JSON) + .build(), + retryTemplate); + imageClient = new OpenAiImageClient(openAiImageApi, OpenAiImageOptions.builder().build(), retryTemplate); + } + + @Test + public void openAiChatTransientError() { + + var choice = new ChatCompletion.Choice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletion expectedChatCompletion = new ChatCompletion("id", List.of(choice), 666L, "model", null, null, + new OpenAiApi.Usage(10, 10, 10)); + + when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedChatCompletion))); + + var result = chatClient.call(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getContent()).isEqualTo("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void openAiChatNonTransientError() { +
when(openAiApi.chatCompletionEntity(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.call(new Prompt("text"))); + } + + @Test + public void openAiChatStreamTransientError() { + + var choice = new ChatCompletionChunk.ChunkChoice(ChatCompletionFinishReason.STOP, 0, + new ChatCompletionMessage("Response", Role.ASSISTANT), null); + ChatCompletionChunk expectedChatCompletion = new ChatCompletionChunk("id", List.of(choice), 666L, "model", null, + null); + + when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(Flux.just(expectedChatCompletion)); + + var result = chatClient.stream(new Prompt("text")); + + assertThat(result).isNotNull(); + assertThat(result.collectList().block().get(0).getResult().getOutput().getContent()).isEqualTo("Response"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void openAiChatStreamNonTransientError() { + when(openAiApi.chatCompletionStream(isA(ChatCompletionRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> chatClient.stream(new Prompt("text"))); + } + + @Test + public void openAiEmbeddingTransientError() { + + EmbeddingList expectedEmbeddings = new EmbeddingList<>("list", + List.of(new Embedding(0, List.of(9.9, 8.8))), "model", new OpenAiApi.Usage(10, 10, 10)); + + when(openAiApi.embeddings(isA(EmbeddingRequest.class))).thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedEmbeddings))); + + var result = embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null)); +
+ assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(List.of(9.9, 8.8)); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void openAiEmbeddingNonTransientError() { + when(openAiApi.embeddings(isA(EmbeddingRequest.class))) + .thenThrow(new RuntimeException("Non Transient Error")); + assertThrows(RuntimeException.class, () -> embeddingClient + .call(new org.springframework.ai.embedding.EmbeddingRequest(List.of("text1", "text2"), null))); + } + + @Test + public void openAiAudioTranscriptionTransientError() { + + var expectedResponse = new StructuredResponse("nl", 6.7f, "Transcription Text", List.of(), List.of()); + + when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); + + AudioTranscriptionResponse result = audioTranscriptionClient + .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac"))); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput()).isEqualTo(expectedResponse.text()); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void openAiAudioTranscriptionNonTransientError() { + when(openAiAudioApi.createTranscription(isA(TranscriptionRequest.class), isA(Class.class))) + .thenThrow(new RuntimeException("Transient Error 1")); + assertThrows(RuntimeException.class, () -> audioTranscriptionClient + .call(new AudioTranscriptionPrompt(new ClassPathResource("speech/jfk.flac")))); + } + + @Test + public void openAiImageTransientError() { + + var expectedResponse = new OpenAiImageResponse(678l, List.of(new Data("url678", "b64", "prompt"))); + + 
when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + .thenThrow(new TransientAiException("Transient Error 1")) + .thenThrow(new TransientAiException("Transient Error 2")) + .thenReturn(ResponseEntity.of(Optional.of(expectedResponse))); + + var result = imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message")))); + + assertThat(result).isNotNull(); + assertThat(result.getResult().getOutput().getUrl()).isEqualTo("url678"); + assertThat(retryListener.onSuccessRetryCount).isEqualTo(2); + assertThat(retryListener.onErrorRetryCount).isEqualTo(2); + } + + @Test + public void openAiImageNonTransientError() { + when(openAiImageApi.createImage(isA(OpenAiImageRequest.class))) + .thenThrow(new RuntimeException("Transient Error 1")); + assertThrows(RuntimeException.class, + () -> imageClient.call(new ImagePrompt(List.of(new ImageMessage("Image Message"))))); + } + +} diff --git a/models/spring-ai-stability-ai/pom.xml b/models/spring-ai-stability-ai/pom.xml index b14b51a17ec..b5180e6faf0 100644 --- a/models/spring-ai-stability-ai/pom.xml +++ b/models/spring-ai-stability-ai/pom.xml @@ -30,12 +30,11 @@ - org.springframework - spring-web - ${spring-framework.version} + org.springframework.ai + spring-ai-retry + ${project.parent.version} - org.springframework diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java index d39594998b5..e632e5235b5 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageClient.java @@ -15,17 +15,24 @@ */ package org.springframework.ai.stabilityai; +import java.util.List; +import java.util.stream.Collectors; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import 
org.springframework.ai.image.*; + +import org.springframework.ai.image.Image; +import org.springframework.ai.image.ImageClient; +import org.springframework.ai.image.ImageGeneration; +import org.springframework.ai.image.ImageOptions; +import org.springframework.ai.image.ImagePrompt; +import org.springframework.ai.image.ImageResponse; +import org.springframework.ai.image.ImageResponseMetadata; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.stabilityai.api.StabilityAiApi; import org.springframework.ai.stabilityai.api.StabilityAiImageOptions; import org.springframework.util.Assert; -import java.util.List; -import java.util.stream.Collectors; - /** * StabilityAiImageClient is a class that implements the ImageClient interface. It * provides a client for calling the StabilityAI image generation API. @@ -50,7 +57,7 @@ public StabilityAiImageClient(StabilityAiApi stabilityAiApi, StabilityAiImageOpt } public StabilityAiImageOptions getOptions() { - return options; + return this.options; } /** @@ -159,17 +166,4 @@ private StabilityAiImageOptions convertOptions(ImageOptions runtimeOptions) { return builder.build(); } - private ImagePrompt createUpdatedPrompt(ImagePrompt prompt) { - ImageOptions runtimeImageModelOptions = prompt.getOptions(); - ImageOptionsBuilder imageOptionsBuilder = ImageOptionsBuilder.builder(); - - if (runtimeImageModelOptions != null) { - if (runtimeImageModelOptions.getModel() != null) { - imageOptionsBuilder.withModel(runtimeImageModelOptions.getModel()); - } - } - ImageOptions updatedImageModelOptions = imageOptionsBuilder.build(); - return new ImagePrompt(prompt.getInstructions(), updatedImageModelOptions); - } - } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java index 680f4ee3620..648dae1a5d9 100644 --- 
a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StabilityAiImageGenerationMetadata.java @@ -35,16 +35,17 @@ public StabilityAiImageGenerationMetadata(String finishReason, Long seed) { } public String getFinishReason() { - return finishReason; + return this.finishReason; } public Long getSeed() { - return seed; + return this.seed; } @Override public String toString() { - return "StabilityAiImageGenerationMetadata{" + "finishReason='" + finishReason + '\'' + ", seed=" + seed + '}'; + return "StabilityAiImageGenerationMetadata{" + "finishReason='" + this.finishReason + '\'' + ", seed=" + + this.seed + '}'; } @Override @@ -53,12 +54,12 @@ public boolean equals(Object o) { return true; if (!(o instanceof StabilityAiImageGenerationMetadata that)) return false; - return Objects.equals(finishReason, that.finishReason) && Objects.equals(seed, that.seed); + return Objects.equals(this.finishReason, that.finishReason) && Objects.equals(this.seed, that.seed); } @Override public int hashCode() { - return Objects.hash(finishReason, seed); + return Objects.hash(this.finishReason, this.seed); } } diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java index 15c8645e481..e1d7c9efa5a 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/StyleEnum.java @@ -20,11 +20,25 @@ */ public enum StyleEnum { - THREE_D_MODEL("3d-model"), ANALOG_FILM("analog-film"), ANIME("anime"), CINEMATIC("cinematic"), - COMIC_BOOK("comic-book"), DIGITAL_ART("digital-art"), ENHANCE("enhance"), FANTASY_ART("fantasy-art"), - ISOMETRIC("isometric"), LINE_ART("line-art"), 
LOW_POLY("low-poly"), MODELING_COMPOUND("modeling-compound"), - NEON_PUNK("neon-punk"), ORIGAMI("origami"), PHOTOGRAPHIC("photographic"), PIXEL_ART("pixel-art"), + // @formatter:off + THREE_D_MODEL("3d-model"), + ANALOG_FILM("analog-film"), + ANIME("anime"), + CINEMATIC("cinematic"), + COMIC_BOOK("comic-book"), + DIGITAL_ART("digital-art"), + ENHANCE("enhance"), + FANTASY_ART("fantasy-art"), + ISOMETRIC("isometric"), + LINE_ART("line-art"), + LOW_POLY("low-poly"), + MODELING_COMPOUND("modeling-compound"), + NEON_PUNK("neon-punk"), + ORIGAMI("origami"), + PHOTOGRAPHIC("photographic"), + PIXEL_ART("pixel-art"), TILE_TEXTURE("tile-texture"); + // @formatter:on private final String text; diff --git a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java index 984098a1c8a..2ee5b2f7f24 100644 --- a/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java +++ b/models/spring-ai-stability-ai/src/main/java/org/springframework/ai/stabilityai/api/StabilityAiApi.java @@ -15,20 +15,18 @@ */ package org.springframework.ai.stabilityai.api; +import java.util.List; +import java.util.function.Consumer; + import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.ObjectMapper; + +import org.springframework.ai.retry.RetryUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import org.springframework.http.client.ClientHttpResponse; import org.springframework.util.Assert; -import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; -import java.io.IOException; -import java.util.List; -import java.util.function.Consumer; - /** * Represents the StabilityAI API. 
*/ @@ -80,35 +78,12 @@ public StabilityAiApi(String apiKey, String model, String baseUrl, RestClient.Bu headers.setContentType(MediaType.APPLICATION_JSON); }; - ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler() { - @Override - public boolean hasError(ClientHttpResponse response) throws IOException { - return response.getStatusCode().isError(); - } - - @Override - public void handleError(ClientHttpResponse response) throws IOException { - if (response.getStatusCode().isError()) { - throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(), - new ObjectMapper().readValue(response.getBody(), ResponseError.class))); - } - } - }; - this.restClient = restClientBuilder.baseUrl(baseUrl) .defaultHeaders(jsonContentHeaders) - .defaultStatusHandler(responseErrorHandler) + .defaultStatusHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER) .build(); } - @JsonInclude(JsonInclude.Include.NON_NULL) - public record ResponseError(@JsonProperty("id") String id, @JsonProperty("name") String name, - @JsonProperty("message") String message - - ) { - - } - @JsonInclude(JsonInclude.Include.NON_NULL) public record GenerateImageRequest(@JsonProperty("text_prompts") List textPrompts, @JsonProperty("height") Integer height, @JsonProperty("width") Integer width, diff --git a/pom.xml b/pom.xml index 57a62fd27e6..da31f6831f7 100644 --- a/pom.xml +++ b/pom.xml @@ -59,6 +59,7 @@ vector-stores/spring-ai-qdrant spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai + spring-ai-retry @@ -115,6 +116,7 @@ 1.17.0 26.33.0 1.7.1 + 2.0.5 3.25.2 @@ -297,6 +299,8 @@ limitations under the License. 
+ **/.antlr/** + **/aot.factories **/.sdkmanrc **/*.adoc **/*.puml diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index e838b6c080e..4f604f768e7 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -33,12 +33,18 @@ ${project.version} - - - org.springframework.ai - spring-ai-pdf-document-reader - ${project.version} - + + org.springframework.ai + spring-ai-retry + ${project.parent.version} + + + + + org.springframework.ai + spring-ai-pdf-document-reader + ${project.version} + org.springframework.ai diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc index 3ac1ccd5b71..46f0decd7f6 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/image/openai-image.adoc @@ -41,6 +41,19 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man === Image Generation Properties +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Image client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +|==== + + The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI. 
[cols="3,5,1"] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc index faf9d3b9720..e540bae43f7 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/mistralai-chat.adoc @@ -49,6 +49,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man === Chat Properties +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the Mistral AI Chat client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +|==== + The prefix `spring.ai.mistralai` is used as the property prefix that lets you connect to OpenAI. [cols="3,5,1"] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc index 2d184853c39..55a767a7d1e 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/clients/openai-chat.adoc @@ -49,6 +49,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man === Chat Properties +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Chat client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. 
| 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +|==== + The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI. [cols="3,5,1"] diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc index 9a9d3c2e5b7..57f19e66b36 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/embeddings/openai-embeddings.adoc @@ -50,6 +50,18 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man === Embedding Properties +The prefix `spring.ai.retry` is used as the property prefix that lets you configure the retry mechanism for the OpenAI Embedding client. + +[cols="3,5,1"] +|==== +| Property | Description | Default + +| spring.ai.retry.max-attempts | Maximum number of retry attempts. | 10 +| spring.ai.retry.backoff.initial-interval | Initial sleep duration for the exponential backoff policy. | 2 sec. +| spring.ai.retry.backoff.multiplier | Backoff interval multiplier. | 5 +| spring.ai.retry.backoff.max-interval | Maximum backoff duration. | 3 min. +|==== + The prefix `spring.ai.openai` is used as the property prefix that lets you connect to OpenAI. 
[cols="3,5,1"] diff --git a/spring-ai-retry/pom.xml b/spring-ai-retry/pom.xml new file mode 100644 index 00000000000..9c2327c9037 --- /dev/null +++ b/spring-ai-retry/pom.xml @@ -0,0 +1,50 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai + 0.8.1-SNAPSHOT + + spring-ai-retry + jar + Spring AI Retry + Spring AI utility project helping with remote call retry + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + org.springframework.retry + spring-retry + ${spring-retry.version} + + + + org.springframework + spring-webflux + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java new file mode 100644 index 00000000000..55dbe5e3d74 --- /dev/null +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/NonTransientAiException.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.ai.retry; + +/** + * Root of the hierarchy of Model access exceptions that are considered non-transient - + * where a retry of the same operation would fail unless the cause of the Exception is + * corrected. + * + * @author Christian Tzolov + */ +public class NonTransientAiException extends RuntimeException { + + public NonTransientAiException(String message) { + super(message); + } + + public NonTransientAiException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java similarity index 53% rename from models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java rename to spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index 8219170532d..ec6b2d671fb 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/common/ApiUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -13,32 +13,43 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.springframework.ai.openai.api.common; +package org.springframework.ai.retry; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.function.Consumer; +import java.time.Duration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; import org.springframework.http.client.ClientHttpResponse; import org.springframework.lang.NonNull; +import org.springframework.retry.RetryCallback; +import org.springframework.retry.RetryContext; +import org.springframework.retry.RetryListener; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.StreamUtils; import org.springframework.web.client.ResponseErrorHandler; /** - * @author Christian Tzolov + * Default retry template and HTTP response error handler shared by the AI model clients. */ -public class ApiUtils { +public class RetryUtils { - public static final String DEFAULT_BASE_URL = "https://api.openai.com"; + private static final Logger logger = LoggerFactory.getLogger(RetryUtils.class); - public static Consumer getJsonContentHeaders(String apiKey) { - return (headers) -> { - headers.setBearerAuth(apiKey); - headers.setContentType(MediaType.APPLICATION_JSON); - }; - }; + public static final RetryTemplate DEFAULT_RETRY_TEMPLATE = RetryTemplate.builder() + .maxAttempts(10) + .retryOn(TransientAiException.class) + .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000)) + .withListener(new RetryListener() { + @Override + public void onError(RetryContext context, + RetryCallback callback, Throwable throwable) { + logger.warn("Retry error. 
Retry count:" + context.getRetryCount(), throwable); + }; + }) + .build(); public static final ResponseErrorHandler DEFAULT_RESPONSE_ERROR_HANDLER = new ResponseErrorHandler() { @@ -52,10 +63,16 @@ public void handleError(@NonNull ClientHttpResponse response) throws IOException if (response.getStatusCode().isError()) { String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8); String message = String.format("%s - %s", response.getStatusCode().value(), error); + /** + * Thrown on 4xx client errors, such as 401 - Incorrect API key provided, + * 401 - You must be a member of an organization to use the API, 429 - + * Rate limit reached for requests, 429 - You exceeded your current quota + * , please check your plan and billing details. + */ if (response.getStatusCode().is4xxClientError()) { - throw new OpenAiApiClientErrorException(message); + throw new NonTransientAiException(message); } - throw new OpenAiApiException(message); + throw new TransientAiException(message); } } }; diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java new file mode 100644 index 00000000000..4da1960255f --- /dev/null +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/TransientAiException.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.retry; + +/** + * Root of the hierarchy of Model access exceptions that are considered transient - where + * a previously failed operation might be able to succeed when the operation is retried + * without any intervention by application-level functionality. + * + * @author Christian Tzolov + */ +public class TransientAiException extends RuntimeException { + + public TransientAiException(String message) { + super(message); + } + + public TransientAiException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java index a3f67f8864d..1328f3488ba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfiguration.java @@ -17,6 +17,7 @@ import java.util.List; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.mistralai.MistralAiChatClient; import org.springframework.ai.mistralai.MistralAiEmbeddingClient; import org.springframework.ai.mistralai.api.MistralAiApi; @@ -30,9 +31,11 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import 
org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; /** @@ -40,7 +43,7 @@ * @author Christian Tzolov * @since 0.8.1 */ -@AutoConfiguration(after = { RestClientAutoConfiguration.class }) +@AutoConfiguration(after = { RestClientAutoConfiguration.class, SpringAiRetryAutoConfiguration.class }) @EnableConfigurationProperties({ MistralAiEmbeddingProperties.class, MistralAiCommonProperties.class, MistralAiChatProperties.class }) @ConditionalOnClass(MistralAiApi.class) @@ -51,13 +54,15 @@ public class MistralAiAutoConfiguration { @ConditionalOnProperty(prefix = MistralAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public MistralAiEmbeddingClient mistralAiEmbeddingClient(MistralAiCommonProperties commonProperties, - MistralAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder) { + MistralAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { var mistralAiApi = mistralAiApi(embeddingProperties.getApiKey(), commonProperties.getApiKey(), - embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder); + embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder, + responseErrorHandler); return new MistralAiEmbeddingClient(mistralAiApi, embeddingProperties.getMetadataMode(), - embeddingProperties.getOptions()); + embeddingProperties.getOptions(), retryTemplate); } @Bean @@ -66,20 +71,22 @@ public MistralAiEmbeddingClient mistralAiEmbeddingClient(MistralAiCommonProperti matchIfMissing = true) public MistralAiChatClient mistralAiChatClient(MistralAiCommonProperties commonProperties, MistralAiChatProperties chatProperties, RestClient.Builder restClientBuilder, - List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { + List toolFunctionCallbacks, FunctionCallbackContext 
functionCallbackContext, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { var mistralAiApi = mistralAiApi(chatProperties.getApiKey(), commonProperties.getApiKey(), - chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder); + chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), restClientBuilder, responseErrorHandler); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); } - return new MistralAiChatClient(mistralAiApi, chatProperties.getOptions(), functionCallbackContext); + return new MistralAiChatClient(mistralAiApi, chatProperties.getOptions(), functionCallbackContext, + retryTemplate); } private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String baseUrl, String commonBaseUrl, - RestClient.Builder restClientBuilder) { + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { var resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; var resoledBaseUrl = StringUtils.hasText(baseUrl) ? 
baseUrl : commonBaseUrl; @@ -87,7 +94,7 @@ private MistralAiApi mistralAiApi(String apiKey, String commonApiKey, String bas Assert.hasText(resolvedApiKey, "Mistral API key must be set"); Assert.hasText(resoledBaseUrl, "Mistral base URL must be set"); - return new MistralAiApi(resoledBaseUrl, resolvedApiKey, restClientBuilder); + return new MistralAiApi(resoledBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); } @Bean diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java index a6d2a56260e..b9899440641 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfiguration.java @@ -17,6 +17,7 @@ import java.util.List; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; @@ -35,15 +36,17 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; +import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; /** * @author Christian Tzolov */ -@AutoConfiguration(after = { RestClientAutoConfiguration.class }) +@AutoConfiguration(after = { RestClientAutoConfiguration.class, 
SpringAiRetryAutoConfiguration.class }) @ConditionalOnClass(OpenAiApi.class) @EnableConfigurationProperties({ OpenAiConnectionProperties.class, OpenAiChatProperties.class, OpenAiEmbeddingProperties.class, OpenAiImageProperties.class, OpenAiAudioTranscriptionProperties.class }) @@ -55,16 +58,17 @@ public class OpenAiAutoConfiguration { matchIfMissing = true) public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProperties, OpenAiChatProperties chatProperties, RestClient.Builder restClientBuilder, - List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext) { + List toolFunctionCallbacks, FunctionCallbackContext functionCallbackContext, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { var openAiApi = openAiApi(chatProperties.getBaseUrl(), commonProperties.getBaseUrl(), - chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder); + chatProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); if (!CollectionUtils.isEmpty(toolFunctionCallbacks)) { chatProperties.getOptions().getFunctionCallbacks().addAll(toolFunctionCallbacks); } - return new OpenAiChatClient(openAiApi, chatProperties.getOptions(), functionCallbackContext); + return new OpenAiChatClient(openAiApi, chatProperties.getOptions(), functionCallbackContext, retryTemplate); } @Bean @@ -72,17 +76,18 @@ public OpenAiChatClient openAiChatClient(OpenAiConnectionProperties commonProper @ConditionalOnProperty(prefix = OpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public EmbeddingClient openAiEmbeddingClient(OpenAiConnectionProperties commonProperties, - OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder) { + OpenAiEmbeddingProperties embeddingProperties, RestClient.Builder restClientBuilder, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { var openAiApi = 
openAiApi(embeddingProperties.getBaseUrl(), commonProperties.getBaseUrl(), - embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder); + embeddingProperties.getApiKey(), commonProperties.getApiKey(), restClientBuilder, responseErrorHandler); return new OpenAiEmbeddingClient(openAiApi, embeddingProperties.getMetadataMode(), - embeddingProperties.getOptions()); + embeddingProperties.getOptions(), retryTemplate); } private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String commonApiKey, - RestClient.Builder restClientBuilder) { + RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) { String resolvedBaseUrl = StringUtils.hasText(baseUrl) ? baseUrl : commonBaseUrl; Assert.hasText(resolvedBaseUrl, "OpenAI base URL must be set"); @@ -90,7 +95,7 @@ private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, String resolvedApiKey = StringUtils.hasText(apiKey) ? apiKey : commonApiKey; Assert.hasText(resolvedApiKey, "OpenAI API key must be set"); - return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder); + return new OpenAiApi(resolvedBaseUrl, resolvedApiKey, restClientBuilder, responseErrorHandler); } @Bean @@ -98,7 +103,9 @@ private OpenAiApi openAiApi(String baseUrl, String commonBaseUrl, String apiKey, @ConditionalOnProperty(prefix = OpenAiImageProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProperties, - OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder) { + OpenAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { + String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? 
imageProperties.getApiKey() : commonProperties.getApiKey(); @@ -108,15 +115,16 @@ public OpenAiImageClient openAiImageClient(OpenAiConnectionProperties commonProp Assert.hasText(apiKey, "OpenAI API key must be set"); Assert.hasText(baseUrl, "OpenAI base URL must be set"); - var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder); + var openAiImageApi = new OpenAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler); - return new OpenAiImageClient(openAiImageApi).withDefaultOptions(imageProperties.getOptions()); + return new OpenAiImageClient(openAiImageApi, imageProperties.getOptions(), retryTemplate); } @Bean @ConditionalOnMissingBean public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConnectionProperties commonProperties, - OpenAiAudioTranscriptionProperties transcriptionProperties) { + OpenAiAudioTranscriptionProperties transcriptionProperties, RetryTemplate retryTemplate, + ResponseErrorHandler responseErrorHandler) { String apiKey = StringUtils.hasText(transcriptionProperties.getApiKey()) ? 
transcriptionProperties.getApiKey() : commonProperties.getApiKey();
@@ -127,10 +135,10 @@ public OpenAiAudioTranscriptionClient openAiAudioTranscriptionClient(OpenAiConne
 		Assert.hasText(apiKey, "OpenAI API key must be set");
 		Assert.hasText(baseUrl, "OpenAI base URL must be set");
 
-		var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder());
+		var openAiAudioApi = new OpenAiAudioApi(baseUrl, apiKey, RestClient.builder(), responseErrorHandler);
 
 		OpenAiAudioTranscriptionClient openAiChatClient = new OpenAiAudioTranscriptionClient(openAiAudioApi,
-				transcriptionProperties.getOptions());
+				transcriptionProperties.getOptions(), retryTemplate);
 
 		return openAiChatClient;
 	}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java
new file mode 100644
index 00000000000..4daf371fcec
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfiguration.java
@@ -0,0 +1,125 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.springframework.ai.retry.NonTransientAiException;
+import org.springframework.ai.retry.TransientAiException;
+import org.springframework.boot.autoconfigure.AutoConfiguration;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.http.client.ClientHttpResponse;
+import org.springframework.lang.NonNull;
+import org.springframework.retry.RetryCallback;
+import org.springframework.retry.RetryContext;
+import org.springframework.retry.RetryListener;
+import org.springframework.retry.support.RetryTemplate;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StreamUtils;
+import org.springframework.web.client.ResponseErrorHandler;
+
+/**
+ * Auto-configuration for the retry infrastructure shared by the AI clients: a default
+ * {@link RetryTemplate} and a default {@link ResponseErrorHandler}. Both beans are
+ * created only when the application context does not already define one, and both
+ * are driven by {@link SpringAiRetryProperties}.
+ *
+ * @author Christian Tzolov
+ */
+@AutoConfiguration
+@ConditionalOnClass(RetryTemplate.class)
+@EnableConfigurationProperties({ SpringAiRetryProperties.class })
+public class SpringAiRetryAutoConfiguration {
+
+	private static final Logger logger = LoggerFactory.getLogger(SpringAiRetryAutoConfiguration.class);
+
+	/**
+	 * Creates the default {@link RetryTemplate}. Only {@link TransientAiException}s are
+	 * retried, up to the configured maximum number of attempts, with exponential backoff
+	 * between attempts; every failed attempt is logged at WARN level.
+	 * @param properties the retry settings (max attempts and backoff)
+	 * @return the retry template
+	 */
+	@Bean
+	@ConditionalOnMissingBean
+	public RetryTemplate retryTemplate(SpringAiRetryProperties properties) {
+		return RetryTemplate.builder()
+			.maxAttempts(properties.getMaxAttempts())
+			.retryOn(TransientAiException.class)
+			.exponentialBackoff(properties.getBackoff().getInitialInterval(), properties.getBackoff().getMultiplier(),
+					properties.getBackoff().getMaxInterval())
+			.withListener(new RetryListener() {
+				@Override
+				public <T, E extends Throwable> void onError(RetryContext context,
+						RetryCallback<T, E> callback, Throwable throwable) {
+					logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
+				}
+			})
+			.build();
+	}
+
+	/**
+	 * Creates the default {@link ResponseErrorHandler}, which turns every HTTP error
+	 * response into either a {@link NonTransientAiException} or a
+	 * {@link TransientAiException}. The exception message has the form
+	 * {@code "<status code> - <response body>"}.
+	 * @param properties the retry settings that select the non-transient status codes
+	 * @return the response error handler
+	 */
+	@Bean
+	@ConditionalOnMissingBean
+	public ResponseErrorHandler responseErrorHandler(SpringAiRetryProperties properties) {
+
+		return new ResponseErrorHandler() {
+
+			@Override
+			public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
+				return response.getStatusCode().isError();
+			}
+
+			@Override
+			public void handleError(@NonNull ClientHttpResponse response) throws IOException {
+				if (response.getStatusCode().isError()) {
+					String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
+					String message = String.format("%s - %s", response.getStatusCode().value(), error);
+					/*
+					 * Non-transient 4xx client errors, such as 401 - Incorrect API key
+					 * provided, 401 - You must be a member of an organization to use the
+					 * API, 429 - Rate limit reached for requests, 429 - You exceeded your
+					 * current quota, please check your plan and billing details.
+					 */
+					if (properties.isNoRetryOnHttpClientErrors() && response.getStatusCode().is4xxClientError()) {
+						throw new NonTransientAiException(message);
+					}
+					// Explicitly configured non-transient codes
+					if (!CollectionUtils.isEmpty(properties.getNoRetryOnHttpCodes())
+							&& properties.getNoRetryOnHttpCodes().contains(response.getStatusCode().value())) {
+						throw new NonTransientAiException(message);
+					}
+					throw new TransientAiException(message);
+				}
+			}
+		};
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java
new file mode 100644
index 00000000000..f560d0ebf58
--- /dev/null
+++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryProperties.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2023 - 2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.ai.autoconfigure.retry;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+/**
+ * Configuration properties, bound under the {@code spring.ai.retry} prefix, that
+ * configure the retry beans created by {@link SpringAiRetryAutoConfiguration}.
+ *
+ * @author Christian Tzolov
+ */
+@ConfigurationProperties(SpringAiRetryProperties.CONFIG_PREFIX)
+public class SpringAiRetryProperties {
+
+	/**
+	 * Prefix of the retry configuration properties.
+	 */
+	public static final String CONFIG_PREFIX = "spring.ai.retry";
+
+	/**
+	 * Maximum number of retry attempts.
+	 */
+	private int maxAttempts = 10;
+
+	/**
+	 * Exponential Backoff properties.
+	 */
+	@NestedConfigurationProperty
+	private Backoff backoff = new Backoff();
+
+	/**
+	 * Whether responses with a 4xx (client error) status code are treated as
+	 * non-transient errors.
+	 */
+	private boolean noRetryOnHttpClientErrors = true;
+
+	/**
+	 * HTTP status codes that are always treated as non-transient errors.
+	 */
+	private List<Integer> noRetryOnHttpCodes = new ArrayList<>();
+
+	/**
+	 * Exponential Backoff properties.
+	 */
+	public static class Backoff {
+
+		/**
+		 * Initial sleep duration.
+		 */
+		private Duration initialInterval = Duration.ofSeconds(2);
+
+		/**
+		 * Backoff interval multiplier.
+		 */
+		private int multiplier = 5;
+
+		/**
+		 * Maximum backoff duration.
+		 */
+		private Duration maxInterval = Duration.ofMinutes(3);
+
+		public Duration getInitialInterval() {
+			return this.initialInterval;
+		}
+
+		public void setInitialInterval(Duration initialInterval) {
+			this.initialInterval = initialInterval;
+		}
+
+		public int getMultiplier() {
+			return this.multiplier;
+		}
+
+		public void setMultiplier(int multiplier) {
+			this.multiplier = multiplier;
+		}
+
+		public Duration getMaxInterval() {
+			return this.maxInterval;
+		}
+
+		public void setMaxInterval(Duration maxInterval) {
+			this.maxInterval = maxInterval;
+		}
+
+	}
+
+	public int getMaxAttempts() {
+		return this.maxAttempts;
+	}
+
+	public void setMaxAttempts(int maxAttempts) {
+		this.maxAttempts = maxAttempts;
+	}
+
+	public Backoff getBackoff() {
+		return this.backoff;
+	}
+
+	public List<Integer> getNoRetryOnHttpCodes() {
+		return this.noRetryOnHttpCodes;
+	}
+
+	public void setNoRetryOnHttpCodes(List<Integer> noRetryOnHttpCodes) {
+		this.noRetryOnHttpCodes = noRetryOnHttpCodes;
+	}
+
+	public boolean isNoRetryOnHttpClientErrors() {
+		return this.noRetryOnHttpClientErrors;
+	}
+
+	public void setNoRetryOnHttpClientErrors(boolean noRetryOnHttpClientErrors) {
+		this.noRetryOnHttpClientErrors = noRetryOnHttpClientErrors;
+	}
+
+}
diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
index f2ced6509de..172f4785d0b 100644
--- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
+++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports
@@ -22,4 +22,4 @@ org.springframework.ai.autoconfigure.vectorstore.azure.AzureVectorStoreAutoConfi
org.springframework.ai.autoconfigure.vectorstore.weaviate.WeaviateVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.neo4j.Neo4jVectorStoreAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.qdrant.QdrantVectorStoreAutoConfiguration - +org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java index c3e56639c0d..25816e8c0e4 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiAutoConfigurationIT.java @@ -24,6 +24,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import reactor.core.publisher.Flux; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -47,7 +48,8 @@ public class MistralAiAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); @Test void generate() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java index 6963855c837..ee4cf9aa7ba 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/MistralAiPropertiesTests.java @@ -16,6 +16,8 @@ package org.springframework.ai.autoconfigure.mistralai; import org.junit.jupiter.api.Test; + +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -33,8 +35,8 @@ public void embeddingProperties() { new ApplicationContextRunner() .withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123", "spring.ai.mistralai.embedding.options.model=MODEL_XYZ") - .withConfiguration( - AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class); var connectionProperties = context.getBean(MistralAiCommonProperties.class); @@ -55,8 +57,8 @@ public void embeddingOverrideConnectionProperties() { new ApplicationContextRunner().withPropertyValues("spring.ai.mistralai.base-url=TEST_BASE_URL", "spring.ai.mistralai.api-key=abc123", "spring.ai.mistralai.embedding.base-url=TEST_BASE_URL2", "spring.ai.mistralai.embedding.api-key=456", "spring.ai.mistralai.embedding.options.model=MODEL_XYZ") - .withConfiguration( - AutoConfigurations.of(RestClientAutoConfiguration.class, 
MistralAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class); var connectionProperties = context.getBean(MistralAiCommonProperties.class); @@ -79,8 +81,8 @@ public void embeddingOptionsTest() { "spring.ai.mistralai.embedding.options.model=MODEL_XYZ", "spring.ai.mistralai.embedding.options.encodingFormat=MyEncodingFormat") - .withConfiguration( - AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) .run(context -> { var connectionProperties = context.getBean(MistralAiCommonProperties.class); var embeddingProperties = context.getBean(MistralAiEmbeddingProperties.class); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java index 0b7972bdfbc..d4259f3ee38 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanIT.java @@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -48,7 +49,8 @@ class PaymentStatusBeanIT { private final 
ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)) .withUserConfiguration(Config.class); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java index 33a5e327c6c..15796f921b2 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusBeanOpenAiIT.java @@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -55,7 +56,8 @@ class PaymentStatusBeanOpenAiIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY"), "spring.ai.openai.chat.base-url=https://api.mistral.ai") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) 
.withUserConfiguration(Config.class); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java index d7a46ece4ee..58bd9e84cc6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/PaymentStatusPromptIT.java @@ -26,6 +26,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -46,7 +47,8 @@ public class PaymentStatusPromptIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mistralai.apiKey=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); public record Transaction(@JsonProperty(required = true, value = "transaction_id") String id) { } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java index 915c378f008..6bf052f5697 100644 --- 
a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/mistralai/tool/WeatherServicePromptIT.java @@ -29,6 +29,7 @@ import org.springframework.ai.autoconfigure.mistralai.MistralAiAutoConfiguration; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Request; import org.springframework.ai.autoconfigure.mistralai.tool.WeatherServicePromptIT.MyWeatherService.Response; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -54,7 +55,8 @@ public class WeatherServicePromptIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.mistralai.api-key=" + System.getenv("MISTRAL_AI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, MistralAiAutoConfiguration.class)); @Test void promptFunctionCall() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java index 2561279f91f..41bbfb8c7b1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiAutoConfigurationIT.java @@ -31,6 +31,7 @@ import org.springframework.core.io.Resource; import 
reactor.core.publisher.Flux; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.openai.OpenAiAudioTranscriptionClient; @@ -49,7 +50,8 @@ public class OpenAiAutoConfigurationIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)); @Test void generate() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java index 0eb67dd6c7b..e14a6a562de 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiPropertiesTests.java @@ -19,6 +19,7 @@ import org.skyscreamer.jsonassert.JSONAssert; import org.skyscreamer.jsonassert.JSONCompareMode; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.ai.openai.OpenAiImageClient; @@ -52,7 +53,8 @@ public void chatProperties() { "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + 
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(OpenAiChatProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -78,7 +80,8 @@ public void transcriptionProperties() { "spring.ai.openai.audio.transcription.options.model=MODEL_XYZ", "spring.ai.openai.audio.transcription.options.temperature=0.55") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -106,7 +109,8 @@ public void chatOverrideConnectionProperties() { "spring.ai.openai.chat.options.model=MODEL_XYZ", "spring.ai.openai.chat.options.temperature=0.55") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(OpenAiChatProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -134,7 +138,8 @@ public void transcriptionOverrideConnectionProperties() { "spring.ai.openai.audio.transcription.options.model=MODEL_XYZ", "spring.ai.openai.audio.transcription.options.temperature=0.55") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + 
RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -159,7 +164,8 @@ public void embeddingProperties() { "spring.ai.openai.api-key=abc123", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -185,7 +191,8 @@ public void embeddingOverrideConnectionProperties() { "spring.ai.openai.embedding.api-key=456", "spring.ai.openai.embedding.options.model=MODEL_XYZ") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -209,7 +216,8 @@ public void imageProperties() { "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var imageProperties = context.getBean(OpenAiImageProperties.class); var connectionProperties = 
context.getBean(OpenAiConnectionProperties.class); @@ -236,7 +244,8 @@ public void imageOverrideConnectionProperties() { "spring.ai.openai.image.options.model=MODEL_XYZ", "spring.ai.openai.image.options.n=3") // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var imageProperties = context.getBean(OpenAiImageProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -304,7 +313,8 @@ public void chatOptionsTest() { "spring.ai.openai.chat.options.user=userXYZ" ) // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var chatProperties = context.getBean(OpenAiChatProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -357,7 +367,8 @@ public void transcriptionOptionsTest() { "spring.ai.openai.audio.transcription.options.temperature=0.55" ) // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var transcriptionProperties = context.getBean(OpenAiAudioTranscriptionProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -390,7 +401,8 @@ public void embeddingOptionsTest() { "spring.ai.openai.embedding.options.user=userXYZ" ) // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, 
OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var connectionProperties = context.getBean(OpenAiConnectionProperties.class); var embeddingProperties = context.getBean(OpenAiEmbeddingProperties.class); @@ -422,7 +434,8 @@ public void imageOptionsTest() { "spring.ai.openai.image.options.user=userXYZ" ) // @formatter:on - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { var imageProperties = context.getBean(OpenAiImageProperties.class); var connectionProperties = context.getBean(OpenAiConnectionProperties.class); @@ -448,7 +461,8 @@ void embeddingActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.embedding.enabled=false") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isEmpty(); @@ -456,7 +470,8 @@ void embeddingActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) 
.run(context -> { assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isNotEmpty(); @@ -465,7 +480,8 @@ void embeddingActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.embedding.enabled=true") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiEmbeddingProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiEmbeddingClient.class)).isNotEmpty(); @@ -477,7 +493,8 @@ void chatActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.chat.enabled=false") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiChatClient.class)).isEmpty(); @@ -485,7 +502,8 @@ void chatActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { 
assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiChatClient.class)).isNotEmpty(); @@ -494,7 +512,8 @@ void chatActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.chat.enabled=true") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiChatProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiChatClient.class)).isNotEmpty(); @@ -507,7 +526,8 @@ void imageActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.image.enabled=false") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiImageClient.class)).isEmpty(); @@ -515,7 +535,8 @@ void imageActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty(); 
assertThat(context.getBeansOfType(OpenAiImageClient.class)).isNotEmpty(); @@ -524,7 +545,8 @@ void imageActivation() { new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.api-key=API_KEY", "spring.ai.openai.base-url=TEST_BASE_URL", "spring.ai.openai.image.enabled=true") - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .run(context -> { assertThat(context.getBeansOfType(OpenAiImageProperties.class)).isNotEmpty(); assertThat(context.getBeansOfType(OpenAiImageClient.class)).isNotEmpty(); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java index a5c0574b448..09e17abe5e5 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackInPromptIT.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -42,7 +43,8 @@ public class FunctionCallbackInPromptIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, 
OpenAiAutoConfiguration.class)); + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)); @Test void functionCallTest() { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java index 9aa664925e2..51bd8721dd8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWithPlainFunctionBeanIT.java @@ -24,6 +24,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -47,7 +48,8 @@ class FunctionCallbackWithPlainFunctionBeanIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .withUserConfiguration(Config.class); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java index df1f6a33276..9a4e70072dd 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/tool/FunctionCallbackWrapperIT.java @@ -23,6 +23,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration; +import org.springframework.ai.autoconfigure.retry.SpringAiRetryAutoConfiguration; import org.springframework.ai.chat.ChatResponse; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -45,7 +46,8 @@ public class FunctionCallbackWrapperIT { private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withPropertyValues("spring.ai.openai.apiKey=" + System.getenv("OPENAI_API_KEY")) - .withConfiguration(AutoConfigurations.of(RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, + RestClientAutoConfiguration.class, OpenAiAutoConfiguration.class)) .withUserConfiguration(Config.class); @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java new file mode 100644 index 00000000000..f15e605fc50 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryAutoConfigurationIT.java @@ -0,0 +1,44 @@ +/* + * Copyright 2023 - 2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.retry; + +import org.junit.jupiter.api.Test; + +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.retry.support.RetryTemplate; +import org.springframework.web.client.ResponseErrorHandler; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class SpringAiRetryAutoConfigurationIT { + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner().withConfiguration( + AutoConfigurations.of(SpringAiRetryAutoConfiguration.class, RestClientAutoConfiguration.class)); + + @Test + void testRetryAutoConfiguration() { + this.contextRunner.run((context) -> { + assertThat(context).hasSingleBean(RetryTemplate.class); + assertThat(context).hasSingleBean(ResponseErrorHandler.class); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java new file mode 100644 index 00000000000..ebf549276d7 --- /dev/null +++ 
b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/retry/SpringAiRetryPropertiesTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.retry; + +import org.junit.jupiter.api.Test; + +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Unit Tests for {@link SpringAiRetryProperties}. 
+ * + * @author Christian Tzolov + */ +public class SpringAiRetryPropertiesTests { + + @Test + public void retryDefaultProperties() { + + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class)) + .run(context -> { + var retryProperties = context.getBean(SpringAiRetryProperties.class); + + assertThat(retryProperties.getMaxAttempts()).isEqualTo(10); + assertThat(retryProperties.isNoRetryOnHttpClientErrors()).isTrue(); + assertThat(retryProperties.getNoRetryOnHttpCodes()).isEmpty(); + assertThat(retryProperties.getBackoff().getInitialInterval().toMillis()).isEqualTo(2000); + assertThat(retryProperties.getBackoff().getMultiplier()).isEqualTo(5); + assertThat(retryProperties.getBackoff().getMaxInterval().toMillis()).isEqualTo(3 * 60000); + }); + } + + @Test + public void retryCustomProperties() { + + new ApplicationContextRunner().withPropertyValues( + // @formatter:off + "spring.ai.retry.max-attempts=100", + "spring.ai.retry.no-retry-on-http-client-errors=false", + "spring.ai.retry.no-retry-on-http-codes=404,500", + "spring.ai.retry.backoff.initial-interval=1000", + "spring.ai.retry.backoff.multiplier=2", + "spring.ai.retry.backoff.max-interval=60000" ) + // @formatter:on + .withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class)) + .run(context -> { + var retryProperties = context.getBean(SpringAiRetryProperties.class); + + assertThat(retryProperties.getMaxAttempts()).isEqualTo(100); + assertThat(retryProperties.isNoRetryOnHttpClientErrors()).isFalse(); + assertThat(retryProperties.getNoRetryOnHttpCodes()).containsExactly(404, 500); + assertThat(retryProperties.getBackoff().getInitialInterval().toMillis()).isEqualTo(1000); + assertThat(retryProperties.getBackoff().getMultiplier()).isEqualTo(2); + assertThat(retryProperties.getBackoff().getMaxInterval().toMillis()).isEqualTo(60000); + }); + } + +}