Skip to content

Refactor and centralize Retry logic #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions models/spring-ai-mistral-ai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,13 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
<version>${spring-framework.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.retry</groupId>
<artifactId>spring-retry</artifactId>
<version>2.0.4</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<!-- Spring Framework -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webflux</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context-support</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package org.springframework.ai.mistralai;

import java.time.Duration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -41,10 +40,8 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand All @@ -70,17 +67,7 @@ public class MistralAiChatClient extends
*/
private final MistralAiApi mistralAiApi;

private final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(MistralAiApi.MistralAiApiException.class)
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
};
})
.build();
private final RetryTemplate retryTemplate;

public MistralAiChatClient(MistralAiApi mistralAiApi) {
this(mistralAiApi,
Expand All @@ -93,46 +80,50 @@ public MistralAiChatClient(MistralAiApi mistralAiApi) {
}

public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options) {
this(mistralAiApi, options, null);
this(mistralAiApi, options, null, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public MistralAiChatClient(MistralAiApi mistralAiApi, MistralAiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
FunctionCallbackContext functionCallbackContext, RetryTemplate retryTemplate) {
super(functionCallbackContext);
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(options, "Options must not be null");
Assert.notNull(retryTemplate, "RetryTemplate must not be null");
this.mistralAiApi = mistralAiApi;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
public ChatResponse call(Prompt prompt) {
// return retryTemplate.execute(ctx -> {
var request = createRequest(prompt, false);

// var completionEntity = this.mistralAiApi.chatCompletionEntity(request);
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
return retryTemplate.execute(ctx -> {

var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}
ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);

List<Generation> generations = chatCompletion.choices()
.stream()
.map(choice -> new Generation(choice.message().content(), Map.of("role", choice.message().role().name()))
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();
var chatCompletion = completionEntity.getBody();
if (chatCompletion == null) {
log.warn("No chat completion returned for prompt: {}", prompt);
return new ChatResponse(List.of());
}

List<Generation> generations = chatCompletion.choices()
.stream()
.map(choice -> new Generation(choice.message().content(),
Map.of("role", choice.message().role().name()))
.withGenerationMetadata(ChatGenerationMetadata.from(choice.finishReason().name(), null)))
.toList();

return new ChatResponse(generations);
// });
return new ChatResponse(generations);
});
}

@Override
public Flux<ChatResponse> stream(Prompt prompt) {
var request = createRequest(prompt, true);

return retryTemplate.execute(ctx -> {
var request = createRequest(prompt, true);

var completionChunks = this.mistralAiApi.chatCompletionStream(request);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,25 @@
*/
package org.springframework.ai.mistralai;

import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.*;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.mistralai.api.MistralAiApi.MistralAiApiException;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

import java.time.Duration;
import java.util.List;

/**
* @author Ricken Bazolo
* @since 0.8.1
Expand All @@ -46,40 +48,33 @@ public class MistralAiEmbeddingClient extends AbstractEmbeddingClient {

private final MistralAiApi mistralAiApi;

private final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(MistralAiApiException.class)
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
};
})
.build();
private final RetryTemplate retryTemplate;

public MistralAiEmbeddingClient(MistralAiApi mistralAiApi) {
this(mistralAiApi, MetadataMode.EMBED);
}

public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode) {
this(mistralAiApi, metadataMode,
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build());
MistralAiEmbeddingOptions.builder().withModel(MistralAiApi.EmbeddingModel.EMBED.getValue()).build(),
RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MistralAiEmbeddingOptions options) {
this(mistralAiApi, MetadataMode.EMBED, options);
this(mistralAiApi, MetadataMode.EMBED, options, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

public MistralAiEmbeddingClient(MistralAiApi mistralAiApi, MetadataMode metadataMode,
MistralAiEmbeddingOptions options) {
MistralAiEmbeddingOptions options, RetryTemplate retryTemplate) {
Assert.notNull(mistralAiApi, "MistralAiApi must not be null");
Assert.notNull(metadataMode, "metadataMode must not be null");
Assert.notNull(options, "options must not be null");
Assert.notNull(retryTemplate, "retryTemplate must not be null");

this.mistralAiApi = mistralAiApi;
this.metadataMode = metadataMode;
this.defaultOptions = options;
this.retryTemplate = retryTemplate;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package org.springframework.ai.mistralai.api;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -25,22 +23,18 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.boot.context.properties.bind.ConstructorBinding;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -70,8 +64,6 @@ public class MistralAiApi {

private WebClient webClient;

private final ObjectMapper objectMapper;

/**
* Create a new client api with DEFAULT_BASE_URL
* @param mistralAiApiKey Mistral api Key.
Expand All @@ -86,75 +78,30 @@ public MistralAiApi(String mistralAiApiKey) {
* @param mistralAiApiKey Mistral api Key.
*/
public MistralAiApi(String baseUrl, String mistralAiApiKey) {
this(baseUrl, mistralAiApiKey, RestClient.builder());
this(baseUrl, mistralAiApiKey, RestClient.builder(), RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
}

/**
* Create a new client api.
* @param baseUrl api base URL.
* @param mistralAiApiKey Mistral api Key.
* @param restClientBuilder RestClient builder.
* @param responseErrorHandler Response error handler.
*/
public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder) {

this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
public MistralAiApi(String baseUrl, String mistralAiApiKey, RestClient.Builder restClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Consumer<HttpHeaders> jsonContentHeaders = headers -> {
headers.setBearerAuth(mistralAiApiKey);
headers.setContentType(MediaType.APPLICATION_JSON);
};

var responseErrorHandler = new ResponseErrorHandler() {

@Override
public boolean hasError(@NonNull ClientHttpResponse response) throws IOException {
return response.getStatusCode().isError();
}

@Override
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
if (response.getStatusCode().isError()) {
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
String message = String.format("%s - %s", response.getStatusCode().value(), error);
if (response.getStatusCode().is4xxClientError()) {
throw new MistralAiApiClientErrorException(message);
}
throw new MistralAiApiException(message);
}
}
};

this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(jsonContentHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();
this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
}

public static class MistralAiApiException extends RuntimeException {

public MistralAiApiException(String message) {
super(message);
}

public MistralAiApiException(String message, Throwable t) {
super(message, t);
}

}

/**
* Thrown on 4xx client errors, such as 401 - Incorrect API key provided, 401 - You
* must be a member of an organization to use the API, 429 - Rate limit reached for
* requests, 429 - You exceeded your current quota , please check your plan and
* billing details.
*/
public static class MistralAiApiClientErrorException extends RuntimeException {

public MistralAiApiClientErrorException(String message) {
super(message);
}

this.webClient = WebClient.builder().baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
}

/**
Expand Down Expand Up @@ -594,7 +541,7 @@ public enum ChatCompletionFinishReason {
// anticipation of future changes. Based on:
// https://github.com/mistralai/client-python/blob/main/src/mistralai/models/chat_completion.py
@JsonProperty("error") ERROR,

@JsonProperty("tool_calls") TOOL_CALLS
// @formatter:on

Expand Down
Loading