Skip to content

Commit f89530c

Browse files
apappascsmarkpollack
authored andcommitted
feat(ollama): add retry template integration to OllamaChatModel
* Update tests that are supposed to fail to not use retry * Upgrade ot use Ollama 0.6.7 Signed-off-by: Alexandros Pappas <alexandros.pappas@yiluhub.com>
1 parent 50f8fa7 commit f89530c

File tree

14 files changed

+340
-11
lines changed

14 files changed

+340
-11
lines changed

auto-configurations/models/spring-ai-autoconfigure-model-ollama/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
</dependency>
3535

3636
<!-- Spring AI auto configurations -->
37+
38+
<dependency>
39+
<groupId>org.springframework.ai</groupId>
40+
<artifactId>spring-ai-autoconfigure-retry</artifactId>
41+
<version>${project.parent.version}</version>
42+
<optional>true</optional>
43+
</dependency>
44+
3745
<dependency>
3846
<groupId>org.springframework.ai</groupId>
3947
<artifactId>spring-ai-autoconfigure-model-tool</artifactId>

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaChatAutoConfigurationTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -41,8 +42,9 @@ public void propertiesTest() {
4142
"spring.ai.ollama.chat.options.topP=0.56",
4243
"spring.ai.ollama.chat.options.topK=123")
4344
// @formatter:on
44-
.withConfiguration(
45-
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
45+
46+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
47+
RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
4648
.run(context -> {
4749
var chatProperties = context.getBean(OllamaChatProperties.class);
4850
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

auto-configurations/models/spring-ai-autoconfigure-model-ollama/src/test/java/org/springframework/ai/model/ollama/autoconfigure/OllamaEmbeddingAutoConfigurationTests.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.junit.jupiter.api.Test;
2020

21+
import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration;
2122
import org.springframework.boot.autoconfigure.AutoConfigurations;
2223
import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration;
2324
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
@@ -26,6 +27,7 @@
2627

2728
/**
2829
* @author Christian Tzolov
30+
* @author Alexandros Pappas
2931
* @since 0.8.0
3032
*/
3133
public class OllamaEmbeddingAutoConfigurationTests {
@@ -41,8 +43,9 @@ public void propertiesTest() {
4143
"spring.ai.ollama.embedding.options.topK=13"
4244
// @formatter:on
4345
)
44-
.withConfiguration(
45-
AutoConfigurations.of(RestClientAutoConfiguration.class, OllamaEmbeddingAutoConfiguration.class))
46+
47+
.withConfiguration(AutoConfigurations.of(SpringAiRetryAutoConfiguration.class,
48+
RestClientAutoConfiguration.class, OllamaChatAutoConfiguration.class))
4649
.run(context -> {
4750
var embeddingProperties = context.getBean(OllamaEmbeddingProperties.class);
4851
var connectionProperties = context.getBean(OllamaConnectionProperties.class);

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,13 @@
6565
import org.springframework.ai.ollama.management.ModelManagementOptions;
6666
import org.springframework.ai.ollama.management.OllamaModelManager;
6767
import org.springframework.ai.ollama.management.PullModelStrategy;
68+
6869
import org.springframework.ai.tool.definition.ToolDefinition;
6970
import org.springframework.ai.util.json.JsonParser;
71+
72+
import org.springframework.ai.retry.RetryUtils;
73+
import org.springframework.retry.support.RetryTemplate;
74+
7075
import org.springframework.util.Assert;
7176
import org.springframework.util.CollectionUtils;
7277
import org.springframework.util.StringUtils;
@@ -129,27 +134,32 @@ public class OllamaChatModel implements ChatModel {
129134

130135
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
131136

137+
private final RetryTemplate retryTemplate;
138+
132139
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
133140
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
134141
this(ollamaApi, defaultOptions, toolCallingManager, observationRegistry, modelManagementOptions,
135-
new DefaultToolExecutionEligibilityPredicate());
142+
new DefaultToolExecutionEligibilityPredicate(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
136143
}
137144

138145
public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
139146
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions,
140-
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
147+
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate, RetryTemplate retryTemplate) {
148+
141149
Assert.notNull(ollamaApi, "ollamaApi must not be null");
142150
Assert.notNull(defaultOptions, "defaultOptions must not be null");
143151
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
144152
Assert.notNull(observationRegistry, "observationRegistry must not be null");
145153
Assert.notNull(modelManagementOptions, "modelManagementOptions must not be null");
146154
Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");
155+
Assert.notNull(retryTemplate, "retryTemplate must not be null");
147156
this.chatApi = ollamaApi;
148157
this.defaultOptions = defaultOptions;
149158
this.toolCallingManager = toolCallingManager;
150159
this.observationRegistry = observationRegistry;
151160
this.modelManager = new OllamaModelManager(this.chatApi, modelManagementOptions);
152161
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
162+
this.retryTemplate = retryTemplate;
153163
initializeModel(defaultOptions.getModel(), modelManagementOptions.pullModelStrategy());
154164
}
155165

@@ -237,7 +247,7 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
237247
this.observationRegistry)
238248
.observe(() -> {
239249

240-
OllamaApi.ChatResponse ollamaResponse = this.chatApi.chat(request);
250+
OllamaApi.ChatResponse ollamaResponse = this.retryTemplate.execute(ctx -> this.chatApi.chat(request));
241251

242252
List<AssistantMessage.ToolCall> toolCalls = ollamaResponse.message().toolCalls() == null ? List.of()
243253
: ollamaResponse.message()
@@ -540,6 +550,8 @@ public static final class Builder {
540550

541551
private ModelManagementOptions modelManagementOptions = ModelManagementOptions.defaults();
542552

553+
private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
554+
543555
private Builder() {
544556
}
545557

@@ -574,13 +586,20 @@ public Builder modelManagementOptions(ModelManagementOptions modelManagementOpti
574586
return this;
575587
}
576588

589+
public Builder retryTemplate(RetryTemplate retryTemplate) {
590+
this.retryTemplate = retryTemplate;
591+
return this;
592+
}
593+
577594
public OllamaChatModel build() {
578595
if (this.toolCallingManager != null) {
579596
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, this.toolCallingManager,
580-
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
597+
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
598+
this.retryTemplate);
581599
}
582600
return new OllamaChatModel(this.ollamaApi, this.defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
583-
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate);
601+
this.observationRegistry, this.modelManagementOptions, this.toolExecutionEligibilityPredicate,
602+
this.retryTemplate);
584603
}
585604

586605
}

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaApi.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
* @author Christian Tzolov
5252
* @author Thomas Vitale
5353
* @author Jonghoon Park
54+
* @author Alexandros Pappas
5455
* @since 0.8.0
5556
*/
5657
// @formatter:off
@@ -64,6 +65,9 @@ public static Builder builder() {
6465

6566
private static final Log logger = LogFactory.getLog(OllamaApi.class);
6667

68+
69+
private static final String DEFAULT_BASE_URL = "http://localhost:11434";
70+
6771
private final RestClient restClient;
6872

6973
private final WebClient webClient;
@@ -77,18 +81,21 @@ public static Builder builder() {
7781
*/
7882
private OllamaApi(String baseUrl, RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder, ResponseErrorHandler responseErrorHandler) {
7983

84+
8085
Consumer<HttpHeaders> defaultHeaders = headers -> {
8186
headers.setContentType(MediaType.APPLICATION_JSON);
8287
headers.setAccept(List.of(MediaType.APPLICATION_JSON));
8388
};
8489

90+
8591
this.restClient = restClientBuilder
8692
.clone()
8793
.baseUrl(baseUrl)
8894
.defaultHeaders(defaultHeaders)
8995
.defaultStatusHandler(responseErrorHandler)
9096
.build();
9197

98+
9299
this.webClient = webClientBuilder
93100
.clone()
94101
.baseUrl(baseUrl)

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelFunctionCallingIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.ai.ollama.api.OllamaOptions;
3737
import org.springframework.ai.ollama.api.tool.MockWeatherService;
3838
import org.springframework.ai.tool.function.FunctionToolCallback;
39+
import org.springframework.ai.retry.RetryUtils;
3940
import org.springframework.beans.factory.annotation.Autowired;
4041
import org.springframework.boot.SpringBootConfiguration;
4142
import org.springframework.boot.test.context.SpringBootTest;
@@ -120,6 +121,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
120121
return OllamaChatModel.builder()
121122
.ollamaApi(ollamaApi)
122123
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
124+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
123125
.build();
124126
}
125127

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,12 @@
5252
import org.springframework.ai.ollama.management.ModelManagementOptions;
5353
import org.springframework.ai.ollama.management.OllamaModelManager;
5454
import org.springframework.ai.ollama.management.PullModelStrategy;
55+
5556
import org.springframework.ai.support.ToolCallbacks;
5657
import org.springframework.ai.tool.annotation.Tool;
58+
59+
import org.springframework.ai.retry.RetryUtils;
60+
5761
import org.springframework.beans.factory.annotation.Autowired;
5862
import org.springframework.boot.SpringBootConfiguration;
5963
import org.springframework.boot.test.context.SpringBootTest;
@@ -371,6 +375,7 @@ public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
371375
.pullModelStrategy(PullModelStrategy.WHEN_MISSING)
372376
.additionalModels(List.of(ADDITIONAL_MODEL))
373377
.build())
378+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
374379
.build();
375380
}
376381

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelMultimodalIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.ai.ollama;
1818

19+
import java.time.Duration;
1920
import java.util.List;
2021

2122
import org.junit.jupiter.api.Test;
@@ -27,11 +28,17 @@
2728
import org.springframework.ai.content.Media;
2829
import org.springframework.ai.ollama.api.OllamaApi;
2930
import org.springframework.ai.ollama.api.OllamaOptions;
31+
import org.springframework.ai.retry.RetryUtils;
32+
import org.springframework.ai.retry.TransientAiException;
3033
import org.springframework.beans.factory.annotation.Autowired;
3134
import org.springframework.boot.SpringBootConfiguration;
3235
import org.springframework.boot.test.context.SpringBootTest;
3336
import org.springframework.context.annotation.Bean;
3437
import org.springframework.core.io.ClassPathResource;
38+
import org.springframework.retry.RetryCallback;
39+
import org.springframework.retry.RetryContext;
40+
import org.springframework.retry.RetryListener;
41+
import org.springframework.retry.support.RetryTemplate;
3542
import org.springframework.util.MimeTypeUtils;
3643

3744
import static org.assertj.core.api.Assertions.assertThat;
@@ -86,9 +93,23 @@ public OllamaApi ollamaApi() {
8693

8794
@Bean
8895
public OllamaChatModel ollamaChat(OllamaApi ollamaApi) {
96+
RetryTemplate retryTemplate = RetryTemplate.builder()
97+
.maxAttempts(1)
98+
.retryOn(TransientAiException.class)
99+
.fixedBackoff(Duration.ofSeconds(1))
100+
.withListener(new RetryListener() {
101+
102+
@Override
103+
public <T extends Object, E extends Throwable> void onError(RetryContext context,
104+
RetryCallback<T, E> callback, Throwable throwable) {
105+
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
106+
}
107+
})
108+
.build();
89109
return OllamaChatModel.builder()
90110
.ollamaApi(ollamaApi)
91111
.defaultOptions(OllamaOptions.builder().model(MODEL).temperature(0.9).build())
112+
.retryTemplate(retryTemplate)
92113
.build();
93114
}
94115

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelObservationIT.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.ollama.api.OllamaApi;
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
37+
import org.springframework.ai.retry.RetryUtils;
3738
import org.springframework.beans.factory.annotation.Autowired;
3839
import org.springframework.boot.SpringBootConfiguration;
3940
import org.springframework.boot.test.context.SpringBootTest;
@@ -47,6 +48,7 @@
4748
* Integration tests for observation instrumentation in {@link OllamaChatModel}.
4849
*
4950
* @author Thomas Vitale
51+
* @author Alexandros Pappas
5052
*/
5153
@SpringBootTest(classes = OllamaChatModelObservationIT.Config.class)
5254
public class OllamaChatModelObservationIT extends BaseOllamaIT {
@@ -169,7 +171,11 @@ public OllamaApi openAiApi() {
169171

170172
@Bean
171173
public OllamaChatModel openAiChatModel(OllamaApi ollamaApi, TestObservationRegistry observationRegistry) {
172-
return OllamaChatModel.builder().ollamaApi(ollamaApi).observationRegistry(observationRegistry).build();
174+
return OllamaChatModel.builder()
175+
.ollamaApi(ollamaApi)
176+
.observationRegistry(observationRegistry)
177+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
178+
.build();
173179
}
174180

175181
}

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatModelTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.springframework.ai.ollama.api.OllamaModel;
3636
import org.springframework.ai.ollama.api.OllamaOptions;
3737
import org.springframework.ai.ollama.management.ModelManagementOptions;
38+
import org.springframework.ai.retry.RetryUtils;
3839

3940
import static org.assertj.core.api.Assertions.assertThat;
4041
import static org.junit.jupiter.api.Assertions.*;
@@ -82,6 +83,7 @@ void buildOllamaChatModel() {
8283
() -> OllamaChatModel.builder()
8384
.ollamaApi(this.ollamaApi)
8485
.defaultOptions(OllamaOptions.builder().model(OllamaModel.LLAMA2).build())
86+
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
8587
.modelManagementOptions(null)
8688
.build());
8789
assertEquals("modelManagementOptions must not be null", exception.getMessage());

0 commit comments

Comments
 (0)